In [None]:
!pip install torch

In [None]:
!pip install torchvision

In [None]:
import sys
print(sys.version)
import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.utils as vutils
print(torch.__version__) # 1.0.1

%matplotlib inline
import matplotlib.pyplot as plt

def show_imgs(x, new_fig=True):
    """Display a batch of images as one tiled grid.

    Args:
        x: batch of image tensors; it is detached and copied to the CPU
            before being handed to ``vutils.make_grid``.
        new_fig: if True, open a new matplotlib figure first; otherwise the
            grid is drawn into the current figure.
    """
    tiled = vutils.make_grid(x.detach().cpu(), nrow=8, normalize=True, pad_value=0.3)
    if new_fig:
        plt.figure()
    # Move the first axis of the grid to the end before plotting
    # (same result as transpose(0, 2) followed by transpose(0, 1)).
    plt.imshow(tiled.permute(1, 2, 0).numpy())

In [None]:
class Discriminator(torch.nn.Module):
    """Two-layer fully connected discriminator.

    Each input sample is flattened to ``inp_dim`` values, passed through a
    512-unit hidden layer with LeakyReLU(0.2), and mapped to a single sigmoid
    output in (0, 1).

    Args:
        inp_dim: number of values per flattened sample
            (default 262144 = 512 * 512).
    """

    def __init__(self, inp_dim=262144):
        super(Discriminator, self).__init__()
        # Remembered so forward() flattens to the size fc1 was built for.
        self.inp_dim = inp_dim
        self.fc1 = nn.Linear(inp_dim, 512)
        self.nonlin1 = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(512, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid scores for the batch ``x``."""
        # Flatten with the configured size instead of a hard-coded 262144, so a
        # non-default inp_dim works; reshape also accepts non-contiguous input.
        x = x.reshape(x.size(0), self.inp_dim)
        # .float() converts the dtype in place on whatever device x lives on.
        # The legacy torch.FloatTensor(x) constructor only accepts CPU float32
        # tensors and raises for CUDA or double inputs.
        x = x.float()
        h = self.nonlin1(self.fc1(x))
        out = self.fc2(h)
        out = torch.sigmoid(out)
        return out

In [None]:
class Generator(nn.Module):
    """Two-layer fully connected generator.

    Maps a latent vector to a single-channel square image with values in
    [-1, 1] (tanh output).

    Args:
        z_dim: size of the latent input vector (default 100).
        img_size: side length of the generated square image (default 512,
            i.e. 262144 output values, matching the previous hard-coded size).
    """

    def __init__(self, z_dim=100, img_size=512):
        super(Generator, self).__init__()
        # Remembered so forward() can reshape the flat output back to an image.
        self.img_size = img_size
        self.fc1 = nn.Linear(z_dim, 512)
        self.nonlin1 = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(512, img_size * img_size)

    def forward(self, x):
        """Return a (batch, 1, img_size, img_size) tensor for latent batch ``x``."""
        h = self.nonlin1(self.fc1(x))
        out = self.fc2(h)
        out = torch.tanh(out)
        out = out.view(out.size(0), 1, self.img_size, self.img_size)
        return out
    

In [None]:
# TODO: the default layer sizes should probably be revisited for the final
# project; they are not necessarily needed for this milestone.
D = Discriminator()
G = Generator()
print(D)
print(G)

In [None]:
# Smoke-test the discriminator on random noise. D flattens every sample to
# 262144 (= 512 * 512) values, so the batch must be 512x512; a 28x28 batch
# cannot be reshaped to that size and raises a RuntimeError.
samples = torch.randn(5, 1, 512, 512)
D(samples)

In [None]:
# List every parameter tensor of the discriminator together with its shape.
for param_name, param in D.named_parameters():
    print(param_name, param.shape)

In [None]:
# List every parameter tensor of the generator together with its shape.
for param_name, param in G.named_parameters():
    print(param_name, param.shape)

In [None]:
# Preview the untrained generator: two random latent vectors of size 100
# (the generator's default z_dim) rendered as an image grid.
z = torch.randn((2, 100))
show_imgs(G(z), new_fig=True)

## Loading data & forward pass

In [None]:
from torch.utils.data import Dataset, DataLoader
import numpy as np

class MRIDataset(Dataset):
    """Dataset of high-resolution images stored as arrays loadable by ``np.load``.

    Pairing with low-resolution images is currently disabled: only the
    high-resolution path list is read and each item is a single array.
    """

    def __init__(self, low_res_file, high_res_file):
        """
        Args:
            low_res_file: text file containing paths to the low resolution
                images. Accepted for interface compatibility but not read yet.
            high_res_file: text file containing paths to the high resolution
                images, one path per line.
        """
        # ndmin=1 keeps a file with a single path as a length-1 array; without
        # it np.loadtxt returns a 0-d array that has no len() and cannot be indexed.
        self.imgs = np.loadtxt(high_res_file, dtype=str, ndmin=1)

    def __len__(self):
        """Number of high-resolution image paths listed in the file."""
        # Index the array directly instead of rebuilding a list on every call.
        return len(self.imgs)

    def __getitem__(self, idx):
        """Load and return the high-resolution array at position ``idx``."""
        high_res_path = str(self.imgs[idx])
        return np.load(high_res_path)

def get_loader(low_res_file, high_res_file, bs, shuffle):
    """Build a DataLoader over an MRIDataset.

    Args:
        low_res_file, high_res_file: path-list text files forwarded to MRIDataset.
        bs: batch size.
        shuffle: whether the loader shuffles the samples.

    Returns:
        A ``DataLoader`` wrapping the newly created dataset.
    """
    return DataLoader(
        MRIDataset(low_res_file, high_res_file),
        batch_size=bs,
        shuffle=shuffle,
    )

# Batches of 64 images, read in file order (no shuffling).
# NOTE(review): these are absolute, machine-specific paths — consider a
# configurable data directory.
dataloader = get_loader(
    low_res_file="/Users/ryanli/Desktop/lr_file.txt",
    high_res_file="/Users/ryanli/Desktop/hr_file.txt",
    bs=64,
    shuffle=False,
)
#print(len(dataloader))
import itertools

#print(len(dataloader))

In [None]:
# Temporary dataset — TODO: switch to DICOM input. It is a reasonable stand-in
# because the images were converted to greyscale, so they are similar to MRI.
dataset = MRIDataset(
    low_res_file="/Users/ryanli/Desktop/lr_file.txt",
    high_res_file="/Users/ryanli/Desktop/hr_file.txt",
)

In [None]:
ix = 1000
# MRIDataset.__getitem__ returns only the high-resolution image, so there is a
# single array to inspect here. The previous version also plotted an undefined
# `y` (NameError) and saved this image under the low-resolution file name.
x = torch.FloatTensor(dataset[ix])
# Show the shape rather than dumping the tensor contents.
print(x.shape)
plt.matshow(x.squeeze().numpy(), cmap=plt.cm.gray)
plt.colorbar()
plt.savefig("high_res_example.png")

# TODO: plot and save the low-resolution counterpart ("low_res_example.png")
# once MRIDataset returns (low_res, high_res) pairs again.

In [None]:
#import itertools
#Dscore = D(x)
#print(len(iter(dataloader)))
#lr_batch, hr_batch = next(itertools.cycle(dataloader))#.next() #minibatch
#lr_batch.shape
#D(lr_batch)
#show_imgs(lr_batch)

## Back to GANs

In [None]:
# Optimizers: plain SGD for both networks, each with learning rate 0.01.
optimizerD = torch.optim.SGD(params=D.parameters(), lr=0.01)
optimizerG = torch.optim.SGD(params=G.parameters(), lr=0.01)

In [None]:
# NOTE: another notebook reported that taking this out improved accuracy —
# look into alternatives later.
criterion = nn.BCELoss()

# Fetch one real mini-batch. The dataset yields bare image arrays (no labels),
# so there is nothing to unpack, and Python 3 iterators are advanced with
# next(...) rather than a .next() method.
x_real = next(iter(dataloader)).float()
print(x_real.shape)

# Size the labels from the actual batch so a batch smaller than 64 still
# matches the discriminator output shape expected by BCELoss.
batch_size = x_real.size(0)
lab_real = torch.ones(batch_size, 1)
lab_fake = torch.zeros(batch_size, 1)

# ---- Discriminator step: real images -> 1, generated images -> 0 ----
optimizerD.zero_grad()

D_x = D(x_real)
lossD_real = criterion(D_x, lab_real)

z = torch.randn(batch_size, 100)  # random latent noise
x_gen = G(z).detach()  # detach: this step must not update the generator
D_G_z = D(x_gen)
lossD_fake = criterion(D_G_z, lab_fake)

lossD = lossD_real + lossD_fake
lossD.backward()
optimizerD.step()

# print(D_x.mean().item(), D_G_z.mean().item())

# ---- Generator step: try to make D label generated images as real ----
optimizerG.zero_grad()

z = torch.randn(batch_size, 100)
D_G_z = D(G(z))
lossG = criterion(D_G_z, lab_real)

lossG.backward()
optimizerG.step()

print(D_G_z.mean().item())

In [None]:
criterion = nn.BCELoss()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device: ', device)
# Re-initialize D and G: the single-step cell above already updated their
# weights, and the fresh copies are placed on the selected device.
D = Discriminator().to(device)
G = Generator().to(device)
# Another paper said RMSprop was better than Adam — unclear how it compares
# with SGD; look into this for the final project.
optimizerD = torch.optim.SGD(D.parameters(), lr=0.2)
optimizerG = torch.optim.SGD(G.parameters(), lr=0.2)
# optimizerD = torch.optim.Adam(D.parameters(), lr=0.0002)
# optimizerG = torch.optim.Adam(G.parameters(), lr=0.0002)

n_epochs = 2
max_steps = 21  # cap on mini-batches per epoch (same step budget as before)
z_dim = 100
steps_per_epoch = min(max_steps, len(dataloader))

collect_x_gen = []
fixed_noise = torch.randn(64, z_dim, device=device)
fig = plt.figure()
plt.ion()

for epoch in range(n_epochs):
    # Iterate the loader itself. Calling next(iter(dataloader)) on every step
    # restarts it each time, so with shuffle=False every step trained on the
    # same first batch.
    for i, x_real in enumerate(dataloader):
        if i >= max_steps:
            break
        # .float().to(device) is all that is needed; wrapping the result in
        # torch.FloatTensor(...) raises for tensors that live on the GPU.
        x_real = x_real.float().to(device)

        # Labels follow the real batch size so a short final batch still works.
        batch_size = x_real.size(0)
        lab_real = torch.ones(batch_size, 1, device=device)
        lab_fake = torch.zeros(batch_size, 1, device=device)

        # ---- Discriminator step ----
        optimizerD.zero_grad()

        D_x = D(x_real)
        lossD_real = criterion(D_x, lab_real)

        z = torch.randn(batch_size, z_dim, device=device)
        x_gen = G(z).detach()  # detach: this step must not update the generator
        D_G_z = D(x_gen)
        lossD_fake = criterion(D_G_z, lab_fake)

        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimizerD.step()

        # ---- Generator step ----
        optimizerG.zero_grad()

        z = torch.randn(batch_size, z_dim, device=device)
        x_gen = G(z)
        D_G_z = D(x_gen)
        lossG = criterion(D_G_z, lab_real)

        lossG.backward()
        optimizerG.step()

        # With the 21-step cap this only fires on the first step of each epoch.
        if i % 100 == 0:
            # Snapshots are for display only, so skip building an autograd graph.
            with torch.no_grad():
                x_gen = G(fixed_noise)
            show_imgs(x_gen, new_fig=False)
            fig.canvas.draw()
            print('e{}.i{}/{} last mb D(x)={:.4f} D(G(z))={:.4f}'.format(
                epoch, i, steps_per_epoch, D_x.mean().item(), D_G_z.mean().item()))
    # End of epoch: keep a copy of the samples generated from the fixed noise.
    with torch.no_grad():
        x_gen = G(fixed_noise)
    collect_x_gen.append(x_gen.detach().clone())

In [None]:
# One figure per epoch: the samples generated from the fixed noise vector.
for epoch_samples in collect_x_gen:
    show_imgs(epoch_samples, new_fig=True)