In [132]:
from google.colab import drive
drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [133]:
import os

In [134]:
import tarfile

In [135]:
zip_file_path = r'/gdrive/MyDrive/Dataset/cats.tar.gz'

In [136]:
with tarfile.open(zip_file_path,'r:gz') as t:
    t.extractall('./catsdata')

In [137]:
data_path = r'/content/catsdata/cats'

In [138]:
len(os.listdir(data_path))

1

In [139]:
import torchvision.transforms as tt

In [140]:
data_tt = tt.Compose([tt.RandomCrop(64,4,padding_mode='reflect'),
                      tt.RandomHorizontalFlip(),
                      tt.ColorJitter(0.8,0.6,0.4),
                      tt.ToTensor()
                      ])

In [141]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

In [142]:
batch_size = 100


In [143]:
dataset = ImageFolder(data_path, transform= data_tt)

In [144]:
data_dl = DataLoader(dataset, batch_size=batch_size, pin_memory=True, num_workers=2)

In [145]:
img,label = dataset[0]
img.shape

torch.Size([3, 64, 64])

In [146]:
import torch.nn as nn

In [147]:
Dis = nn.Sequential(
    nn.Conv2d(3,12,kernel_size=3,stride=1,padding=1), #input 3x64x64
    nn.LeakyReLU(),
    nn.MaxPool2d(2,2), #output 12x32x32
    
    nn.Conv2d(12,24,kernel_size=3,stride=1,padding=1),
    nn.LeakyReLU(),
    nn.MaxPool2d(2,2), #output 24x16x16
    
    nn.Conv2d(24,48, kernel_size=3, stride=1, padding=1),
    nn.LeakyReLU(),
    nn.MaxPool2d(2,2), #output 48x8x8
    
    nn.Conv2d(48,96,kernel_size=3, stride=1, padding=1),
    nn.LeakyReLU(),
    nn.MaxPool2d(2,2), #output 96x4x4
    
    nn.Conv2d(96,192, kernel_size=3, stride=1, padding=1),
    nn.LeakyReLU(),
    
    nn.AdaptiveAvgPool2d(1), #output 192x1x1
    
    nn.Flatten(),
    
    nn.Linear(192,1),
    
    nn.Sigmoid()
)

In [148]:
latent_size = 96
Gen = nn.Sequential(
    nn.ConvTranspose2d(latent_size, 192,kernel_size=4,stride=1, padding=1), #output 192x4x4
    nn.ReLU(),
    
    nn.ConvTranspose2d(192,96, kernel_size=4, stride=2, padding=1), #output 96x8x8
    nn.ReLU(),
    
    nn.ConvTranspose2d(96,48,kernel_size=4,stride=2,padding=1), #output 48x16x16
    nn.ReLU(),
    
    nn.ConvTranspose2d(48,24, kernel_size=4, stride=2, padding=1), #output 24x32x32
    nn.ReLU(),
    
    nn.ConvTranspose2d(24,3,kernel_size=4, stride=2, padding=1), #output 3x64x64
    nn.Tanh(),
)

In [149]:
import torch

In [150]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [151]:
Dis.to(device)
Gen.to(device)

Sequential(
  (0): ConvTranspose2d(96, 192, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): ConvTranspose2d(192, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (3): ReLU()
  (4): ConvTranspose2d(96, 48, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (5): ReLU()
  (6): ConvTranspose2d(48, 24, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): ReLU()
  (8): ConvTranspose2d(24, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (9): Tanh()
)

In [152]:
loss_func = nn.BCELoss()
dis_opt_func = torch.optim.Adam(Dis.parameters(), lr=2e-4)
gen_opt_func = torch.optim.Adam(Gen.parameters(), lr=2e-4)

In [153]:
def disTraining(images):
    fake_labels = torch.zeros(100,1)
    real_labels = torch.ones(100,1)
    
    fake_labels = fake_labels.to(device)
    real_labels = real_labels.to(device)
    
    preds = Dis(images)
    dis_loss_real = loss_func(preds, real_labels)
    
    x = torch.randn(100,latent_size,1,1)
    x = x.to(device)
    
    fake_preds = Gen(x)
    fake_preds_dis = Dis(fake_preds)
    dis_loss_fake = loss_func(fake_preds_dis, fake_labels)
    
    dis_loss = dis_loss_fake + dis_loss_real
    
    dis_opt_func.zero_grad()
    gen_opt_func.zero_grad()
    
    dis_loss.backward()
    
    dis_opt_func.step()
    
    return dis_loss

In [154]:
def genTraining():
    x = torch.randn(100,latent_size,1,1)
    x = x.to(device)
    
    labels = torch.ones(100,1)
    labels = labels.to(device)
    
    fake_preds = Gen(x)
    fake_preds_dis = Dis(fake_preds)
    
    gen_loss = loss_func(fake_preds_dis, labels)
    
    dis_opt_func.zero_grad()
    gen_opt_func.zero_grad()
    
    gen_loss.backward()
    
    gen_opt_func.step()
    
    return gen_loss

In [155]:
from torchvision.utils import save_image
import  os

In [156]:
ran_vec = torch.randn(100,latent_size,1,1).to(device)

def save_after_epoch(i):
    out =Gen(ran_vec)
    
    name = f'saved_image_after_epoch{i+1}.png'
    print(f'saving {name}')
    save_image(out, os.path.join('./sample_data', name), nrow = 10)

In [157]:
def fit(num_epochs):
    for epoch in range(num_epochs):
        for images,_ in data_dl:
            images.to(device)
            
            d_loss = disTraining(images)
            g_loss = genTraining()
        print(f'd_loss: {d_loss} and g_loss: {g_loss}')
        save_after_epoch(epoch)

In [158]:
fit(100)

ValueError: ignored