In [None]:
'Hello World!'

'Hello World!'

In [None]:
# only use this when the folder is not created in the drive and also you need to mount
from google.colab import drive
import os

drive.mount('/content/drive', force_remount=True)

drive_path = "/content/drive/MyDrive/GAN_Project"
os.makedirs(drive_path, exist_ok=True)

# testing file
test_file = os.path.join(drive_path, "connection_test.txt")
with open(test_file, "w") as f:
    f.write("Drive is connected!")

if os.path.exists(test_file):
    print("‚úÖ DRIVE VERIFIED: Connection is live. You can start training.")
else:
    raise RuntimeError("‚ùå DRIVE FAILURE: Files are not writing to Drive. Check permissions.")

In [None]:
# Mount and veriy only when the folder is already created in the drive
drive.mount('/content/drive', force_remount=True)
drive_save_path = "/content/drive/MyDrive/GAN_Project"
checkpoint_path = os.path.join(drive_save_path, "checkpoint_epoch_9.pth")

In [None]:
%%writefile model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class MappingNetwork(nn.Module):
    def __init__(self, z_dim=512, w_dim=512):
        super().__init__()
        layers = []
        for _ in range(8):
            layers.append(nn.Linear(w_dim, w_dim))
            layers.append(nn.LeakyReLU(0.2)) # Standard StyleGAN leakiness
        self.mapping = nn.Sequential(*layers)
    
    def forward(self, z):
        # Normalize the input latent vector
        z = z / (z.norm(dim=1, keepdim=True) + 1e-8)
        return self.mapping(z)

class StyleConv(nn.Module):
    def __init__(self, in_ch, out_ch, w_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_ch, in_ch, 3, 3))
        self.style = nn.Linear(w_dim, in_ch)
        self.noise_strength = nn.Parameter(torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(out_ch))
    
    def forward(self, x, w, noise):
        b, c, h, w_ = x.shape
        # Weight demodulation/modulation logic
        style = self.style(w).view(b, 1, c, 1, 1)
        weight = self.weight.unsqueeze(0) * style
        weight = weight.view(-1, c, 3, 3)

        x = x.view(1, -1, h, w_)
        x = F.conv2d(x, weight, padding=1, groups=b)
        x = x.view(b, -1, h, w_)

        x = x + self.noise_strength * noise
        return x + self.bias.view(1, -1, 1, 1)

class Generator(nn.Module):
    def __init__(self, z_dim=512, w_dim=512):
        super().__init__()
        self.mapping = MappingNetwork(z_dim, w_dim)
        self.const = nn.Parameter(torch.randn(1, 512, 4, 4))
        self.layers = nn.ModuleList([
            StyleConv(512, 512, w_dim),
            StyleConv(512, 256, w_dim),
            StyleConv(256, 128, w_dim),
            StyleConv(128, 64, w_dim),
        ])
        self.to_rgb = nn.Conv2d(64, 3, 1)

    def forward(self, z):
        w = self.mapping(z)
        x = self.const.repeat(z.size(0), 1, 1, 1)

        for layer in self.layers:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device)
            x = layer(x, w, noise)
            x = F.leaky_relu(x, 0.2)

        return torch.tanh(self.to_rgb(x))

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        def block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 4, 2, 1),
                nn.LeakyReLU(0.2)
            )
        self.net = nn.Sequential(
            block(3, 64),
            block(64, 128),
            block(128, 256),
            block(256, 512),
            nn.Flatten(),
            nn.Linear(512*4*4, 1)
        )

    def forward(self, x):
        return self.net(x)

Writing model.py


In [5]:
%%writefile utils.py
import torch
import os
from glob import glob
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, root, size=64):
        # Fixed the glob import usage
        self.files = glob(os.path.join(root, "**/*.jpg"), recursive=True)
        
        if len(self.files) == 0:
            print(f"‚ö†Ô∏è Warning: No .jpg files found in {root}. Check path or extensions.")
            
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3) # Scales to [-1, 1]
        ])
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert("RGB")
        return self.transform(img)

def Gradient_Penalty(D, real, fake):
    # CRITICAL: We need to know which device to use
    device = real.device 
    
    alpha = torch.rand(real.size(0), 1, 1, 1).to(device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    out = D(interp)

    grads = torch.autograd.grad(
        outputs=out,
        inputs=interp,
        grad_outputs=torch.ones_like(out),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Calculate the norm correctly across the image dimensions
    grads = grads.view(grads.size(0), -1)
    gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    return gp

Writing utils.py


In [6]:
import os
import shutil
from kagglehub import kagglehub

os.environ["KAGGLE_USERNAME"] = "tejaskumarvurs"
os.environ["KAGGLE_KEY"] = "80fbc5d540819df3b4666ae5df969af9"

path = kagglehub.dataset_download("tejaskumarvurs/gen-ai-animal-dataset")

dest = "/content/animal_data"
if not os.path.exists(dest):
    shutil.copytree(path, dest)

print(f"‚úÖ Dataset is ready at: {dest}")

Downloading from https://www.kaggle.com/api/v1/datasets/download/tejaskumarvurs/gen-ai-animal-dataset?dataset_version_number=1...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8.91G/8.91G [01:26<00:00, 110MB/s] 

Extracting files...





‚úÖ Dataset is ready at: /content/animal_data


In [7]:
!find / -name "gen-ai-animal-dataset" -type d 2>/dev/null

/root/.cache/kagglehub/datasets/tejaskumarvurs/gen-ai-animal-dataset


In [8]:
# Force a link from the hidden system path to your visible content folder
!ln -s /root/.cache/kagglehub/datasets/tejaskumarvurs/gen-ai-animal-dataset/versions/1 /content/dataset

# Now check if Colab can 'see' into that shortcut
!ls /content/dataset | head -n 5

!ls /content/dataset/Camel | head -n 5

Bear
Brown bear
Bull
Butterfly
Camel
01d5030b1bb698d6.jpg
0215b972cb19e575.jpg
02db256a75c9419a.jpg
035205673c0ae617.jpg
049a26d67bd3192d.jpg


In [9]:
DATA_PATH = "/content/dataset"

def count_files(directory):
    return sum([len(files) for r, d, files in os.walk(directory)])

print(f"Total files accessible for T4 training: {count_files(DATA_PATH)}")

Total files accessible for T4 training: 29071


In [10]:
path = "/content/dataset"

if os.path.exists(path):
    contents = os.listdir(path)
    print(f"‚úÖ Path exists!")
    print(f"Items inside '{path}': {contents[:5]}") # Shows first 5 items
    
    # Check if the first item is a file or a folder
    if len(contents) > 0:
        first_item = os.path.join(path, contents[0])
        if os.path.isdir(first_item):
            print("üìÅ Found subfolders (Classes). Use datasets.ImageFolder(path)")
        else:
            print("üñºÔ∏è Found direct files. Use a custom Dataset class.")
else:
    print(f"‚ùå Path NOT found: {path}")
    print("Checking /content/ to see what is actually there:")
    print(os.listdir("/content/"))

‚úÖ Path exists!
Items inside '/content/dataset': ['Crab', 'Sea turtle', 'Parrot', 'Elephant', 'Koala']
üìÅ Found subfolders (Classes). Use datasets.ImageFolder(path)


In [11]:
import importlib
import utils
import model

# This forces the notebook to read the new versions of your files, these are required when model or utils have been changed.
importlib.reload(utils)
importlib.reload(model)

from model import Generator, Discriminator
from utils import ImageDataset, Gradient_Penalty

In [12]:
import os
import torch
from glob import glob
from torchvision.utils import save_image
from torch.utils.data import DataLoader

def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    path = "/content/dataset" 
    
    # Ensure local and Drive folders exist
    os.makedirs("samples", exist_ok=True)
    drive_save_path = "/content/drive/MyDrive/GAN_Project"
    os.makedirs(drive_save_path, exist_ok=True)

    dataset = ImageDataset(path, size=64)
    print(f"‚úÖ Dataset loaded with {len(dataset)} images.")
    loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, pin_memory=True)

    G = Generator().to(device)
    D = Discriminator().to(device)

    g_opt = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.0, 0.99))
    d_opt = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.0, 0.99))

    # Initialize g_loss so the print statement doesn't crash
    g_loss = torch.tensor(0.0)

    for epoch in range(50):
        for i, real in enumerate(loader):
            real = real.to(device)
            z = torch.randn(real.size(0), 512).to(device)
            fake = G(z)

            # Discriminator Update
            d_loss = D(fake).mean() - D(real).mean()
            gp = Gradient_Penalty(D, real, fake)
            d_total = d_loss + 10 * gp

            d_opt.zero_grad()
            d_total.backward()
            d_opt.step()

            # Generator Update (Every 5 steps)
            if i % 5 == 0:
                # We need to re-generate or detach if we used fake above
                g_loss = -D(G(z)).mean()
                g_opt.zero_grad()
                g_loss.backward()
                g_opt.step()

        # --- PERIODIC SAVE (Safety Net) ---
        if (epoch + 1) % 10 == 0:
            ckpt_path = os.path.join(drive_save_path, f"checkpoint_epoch_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'G_state': G.state_dict(),
                'D_state': D.state_dict(),
                'g_opt': g_opt.state_dict(),
                'd_opt': d_opt.state_dict()
            }, ckpt_path)
            print(f"üíæ Checkpoint saved: {ckpt_path}")

        # Save Sample Images
        with torch.no_grad():
            samples = G(torch.randn(16, 512).to(device))
            save_image(samples, f"samples/epoch_{epoch}.png", normalize=True)

        print(f"Epoch {epoch} | D: {d_total.item():.3f} | G: {g_loss.item():.3f}")

    # Final Save
    torch.save(G.state_dict(), os.path.join(drive_save_path, "generator_final.pth"))
    print("üèÅ Training complete. Final model saved to Drive!")

In [None]:
if __name__ == "__main__":
    train()

‚úÖ Dataset loaded with 29071 images.


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 0 | D: -126.259 | G: -58.254
Epoch 1 | D: -44.432 | G: -135.634
Epoch 2 | D: -72.408 | G: -189.737
Epoch 3 | D: -88.329 | G: -176.951
Epoch 4 | D: -1.748 | G: 79.333
Epoch 5 | D: -6.057 | G: -17.648
Epoch 6 | D: -12.065 | G: -42.308
Epoch 7 | D: -15.027 | G: -48.577
Epoch 8 | D: -8.847 | G: -107.692
üíæ Checkpoint saved: /content/drive/MyDrive/GAN_Project/checkpoint_epoch_9.pth
Epoch 9 | D: -10.874 | G: -73.063
Epoch 10 | D: -13.779 | G: -42.733
Epoch 11 | D: -9.729 | G: -22.208
Epoch 12 | D: -3.548 | G: -11.995
Epoch 13 | D: -7.448 | G: -24.200
Epoch 14 | D: -6.425 | G: -6.672
Epoch 15 | D: -4.407 | G: -39.681
Epoch 16 | D: -6.710 | G: -16.082
