In [12]:
import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as dset
from torch.utils.data import Dataset
import numpy as np
import os

In [13]:
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
pwd

'/home/praveen/projects/DomainAdaptation/SuperResGANUnet'

In [15]:
class Params:
    def __init__(self):
        self.batchSize = 32
        self.Stage1imageSize = 64
        self.Stage2imageSize = 64
        self.LAMBDA = 10
        self.lr= 0.0002
        self.nc = 1
        self.nz = 100
        self.ngf = 64
        self.ndf = 64
        #for unet 
        self.nc_out = 1
        self.num_downsample = 4
        self.dataroot = '/home/praveen/projects/Speech/postnet_experiments/Postnet/speech_scripts/single_npy_dumps'
        self.metadata_dir = self.dataroot
        self.workers = 1
        self.restart = 'restart'
        self.cuda = True
        self.beta1 = 0.5
opt = Params()

In [16]:
opt.nc

1

In [17]:
#Image superresolution [64x64->128x128]
#We would like to do supervised training so as to produce 128x128 images from 64x64 
#Together with this, we add adversarial terms. This is a rather roundabout way of reimplementing pix2pix maybe

In [18]:
#Pull out things from https://github.com/pytorch/vision/tree/master/torchvision/datasets to create 
#a new dataset class
#We want a dataloader that can emit both 64x64 and 128x128 data at the same time with 'enumerate'
#Then we make the generator produce 128x128 taking in 64x64 as input which we then train

In [19]:
def make_dataset(dir):
    import os
    images = []
    d = os.path.expanduser(dir)
    print('d',d)
    
    if not os.path.exists(dir):
        print('path does not exist')

    for root, _, fnames in sorted(os.walk(d)):
        print('root', root)
        for fname in sorted(fnames):
            path = os.path.join(root, fname)
            images.append(path)
    return images

In [30]:
def random_crop(source, target, cropsize=64):
    width = source.shape[0]
    height = source.shape[1]
    
    j = np.random.randint(width)
    i = np.random.randint(height)

        
    while (j+cropsize//2>width) or (j-cropsize//2<0):
        #print(j+cropsize,j-cropsize)
        j = np.random.randint(width)
        
    #print('end j',j)
    
    jplus = j+cropsize//2
    jminus = j-cropsize//2
        
    while (i+cropsize//2>height) or (i-cropsize//2<0):
        #print(i+cropsize,i-cropsize)
        i = np.random.randint(height)
        
    #print('end i', i)
    
    iplus = i+cropsize//2
    iminus = i-cropsize//2
        
    #print(jminus,jplus,iminus,iplus)
    
    cropped_source = source[jminus:jplus,iminus:iplus]
    cropped_target = target[jminus:jplus,iminus:iplus]
    
    
    return cropped_source, cropped_target
    

In [31]:
def get_mels(path,metadata_file):
    import os 
    source = []
    target = []
    
    with open(metadata_file,'r') as metafile:
        entries = [entry.split('\n')[0] for entry in metafile]
    metafile.close()
    
    for file in entries:
        src_file = np.load(os.path.join(path,'recon_'+file+'.npy'))
        tgt_file = np.load(os.path.join(path,'target_'+file+'.npy'))
        
        cropped_src, cropped_tgt = random_crop(src_file,tgt_file)
        
        source.append(cropped_src)
        target.append(cropped_tgt)
        
    return source, target
    

In [32]:
path = os.path.join(opt.dataroot, 'train')
metadata_file = os.path.join(opt.metadata_dir,'train.txt')

In [33]:
source, target = get_mels(path,metadata_file)

In [34]:
print(len(source), len(target))

9500 9500


In [None]:
def pil_loader(path):
    from PIL import Image
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

In [37]:
class ImageFolder(Dataset):
    """A generic data loader where the images are arranged in this way: ::
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png
    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid_file (used to check of corrupt files)
     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, data_path, metadata_file):
        #print('opt.dataroot',opt.dataroot)
        self.src_samples, self.tgt_samples = get_mels(data_path,metadata_file)
        #self.imgs = self.samples
        #self.transformA = transformA
        
        
    def __getitem__(self, index):
        """
        Args:
        index (int): Index
        Returns:
        tuple: (sample, target) where target is class_index of the target class.
        """
        #path = self.samples[index]
        #sample = pil_loader(path)
        
        #if self.transformA is not None:
        #sampleA = self.transformA(sample)
        
        return self.src_samples[index], self.tgt_samples[index]
       

    def __len__(self):
        return len(self.src_samples)


In [38]:
data_path = os.path.join(opt.dataroot, 'train')
metadata_file = os.path.join(opt.metadata_dir,'train.txt')

In [39]:
dataset = ImageFolder(data_path,metadata_file)

#Now we create a dataloader that dumps out both 64x64 and 128x128 when called with 'enumerate'
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

In [42]:
for i,(src,tgt) in enumerate(dataloader):
    print(i)
    print('src.size()',src.size())
    print('tgt.size()',tgt.size())

0
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
1
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
2
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
3
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
4
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
5
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
6
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
7
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
8
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
9
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
10
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
11
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
12
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
13
src.size() torch.Size([32, 64, 6

154
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
155
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
156
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
157
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
158
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
159
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
160
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
161
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
162
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
163
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
164
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
165
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
166
src.size() torch.Size([32, 64, 64])
tgt.size() torch.Size([32, 64, 64])
167
src.size

In [None]:
#dataloaderStage1, dataloaderStage2 = get_data_loaders(opt)

In [None]:
#%aimport model
#from model import G_Stage1
#from model import D_Stage1
#from model import G_Stage2
from model import D_Stage2
from model import get_unet_generator
#from model import UnetGenerator
#from model import D_Stage2_4x4

In [None]:
opt.nc

In [None]:
D2 = D_Stage2(opt.nc,opt.ndf)
print(opt.nc)
G2 = get_unet_generator(opt.nc, opt.nc_out, opt.num_downsample)

if opt.cuda:
    D2 = D2.cuda()
    G2 = G2.cuda()

In [None]:

print(G2)

In [None]:
x = torch.randn(1,3,64,64)
x = x.cuda()

In [None]:
from PIL import Image
import numpy as np
#x0 = Image.open('./data/source_9.jpg')
source = np.load('./data/recon_9.npy')
target = np.load('./data/target_9.npy')

In [None]:
print(source.shape[0],source.shape[1])

In [None]:
while(True):
    cropped_source, cropped_target = random_crop(source,target,64)
    if cropped_source.max()>0:
        break

In [None]:
cropped_source.max()

In [None]:
cimage = torch.from_numpy(cropped_source)

In [None]:
cimage.size()

In [None]:
cimage = cimage.unsqueeze(0).unsqueeze(0)

In [None]:
cimage.size()

In [None]:
c2 = cimage

In [None]:
c2.size()

In [None]:
c2 = c2.cuda()

In [None]:
y=G2(c2)

In [None]:
y.size()

In [None]:
print(y.size())

In [None]:
def display_image(image):
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.pyplot import figure
    figure(num=None, figsize=(4, 6), dpi=80, facecolor='w', edgecolor='k')
    plt.imshow(image,origin='lower')
    plt.show()

In [None]:
def write_image(image,output_file):
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.pyplot import figure
    figure(num=None, figsize=(50, 8), dpi=80, facecolor='w', edgecolor='k')
    plt.imshow(image,origin='lower')
    plt.savefig(output_file)

In [None]:
cout = y.detach().squeeze(0).squeeze(0).cpu().numpy()
display_image(cout)

In [None]:
display_image(cropped_source.T)

In [None]:
display_image(cropped_target.T)

In [None]:
display_image(source.T)

In [None]:
display_image(target.T)

In [None]:
write_image(source.T,'output_source')

In [None]:
write_image(target.T,'output_target')

In [None]:
%aimport train2
from train2 import run_trainer2

In [None]:
run_trainer2(dataloader, G2, D2, opt)

In [None]:
z= torch.randn(1,3,64,512)
z = z.cuda()

In [None]:
zout = G2(z)

In [None]:
zout.size()

In [None]:
def plot_mel(mel):
    import matplotlib.pyplot as plt
    import numpy as np
    import librosa.display
    import os
    plt.figure(figsize=(10, 4))
    #librosa.display.specshow(librosa.power_to_db(mel,ref=np.max),
    #                         y_axis='mel', fmax=8000, x_axis='time')
    librosa.display.specshow(mel,
                             y_axis='mel', fmax=8000, x_axis='time',cmap='magma')
    #plt.colorbar(format='%+2.0f dB')
    plt.title('Mel spectrogram ')
    plt.tight_layout()

    plt.show()
    plt.close()

In [None]:
plot_mel(source.T)

In [None]:
plot_mel(target.T)