In [None]:
import os
import math
import time
import random
import warnings
warnings.filterwarnings("ignore")

import yaml
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import seaborn as sn
import albumentations as A
import torch
from torch.utils import data as torch_data
from torch import nn as torch_nn
from torch.nn import functional as torch_functional
import torchvision
from sklearn import metrics as sk_metrics
from sklearn import model_selection as sk_model_selection
import neptune

In [None]:
# Central experiment configuration; every key below is read by a later cell.
config = {
    "seed": 42,  # random_state of the train/valid split
    
    "valid_size": 0.3,  # fraction of index.csv held out for validation
    "image_size": (512, 512),  # size handed to cv2.resize for every image
    
    "train_batch_size": 4,  # the train preview grid assumes 4 images per batch
    "valid_batch_size": 1,
    "test_batch_size": 1,  # the error-collection cell compares one sample per batch
    
    "model": "mobilenet_v2",  # selects one of the init_model_* builders
    
    "max_epochs": 50,  # upper bound; early stopping may end training sooner
    "model_save_path": "model-best.torch",  # checkpoint written on every validation improvement
    "patience_stop": 3,  # epochs without validation improvement before stopping
    
    "optimizer": "adam",
    "adam_lr": 0.0001,
    
    "criterion": "cross_entropy",
}

In [None]:
# The directory to the dataset.
# Export LEGO_DATA_DIR to point the notebook at another copy of the data
# without editing this cell; the original local path remains the fallback so
# existing runs behave exactly as before.
BASE_DIR = os.environ.get(
    "LEGO_DATA_DIR",
    '/Users/tejakolla/Documents/sem-2/Deep_learning/project-1-teja2002/archive',
)
# CSV files that sit directly inside the dataset directory
PATH_INDEX = os.path.join(BASE_DIR, "index.csv")
PATH_TEST = os.path.join(BASE_DIR, "test.csv")
PATH_METADATA = os.path.join(BASE_DIR, "metadata.csv")

In [None]:
# Load the image index and carve out a validation subset from it.
df = pd.read_csv(PATH_INDEX)

# Stratifying on class_id keeps the class proportions equal in both subsets.
split_options = dict(
    test_size=config["valid_size"],
    random_state=config["seed"],
    stratify=df['class_id'],
)
tmp_train, tmp_valid = sk_model_selection.train_test_split(df, **split_options)


def get_paths_and_targets(tmp_df):
    """Split an index frame into image paths and class ids.

    Args:
        tmp_df: DataFrame with a "path" column (joined onto BASE_DIR here)
            and a "class_id" column.

    Returns:
        (paths, targets): a list of path strings and the raw ``class_id``
        values array, in the frame's row order.
    """
    # Prefix every file name from the index with the dataset root
    full_paths = [os.path.join(BASE_DIR, rel_path) for rel_path in tmp_df["path"].values]
    class_ids = tmp_df["class_id"].values
    return full_paths, class_ids


# The test split lives in its own CSV; train/valid come from the split above.
df_test = pd.read_csv(PATH_TEST)

# Resolve file paths and targets for the three splits in one pass
(
    (train_paths, train_targets),
    (valid_paths, valid_targets),
    (test_paths, test_targets),
) = (get_paths_and_targets(split_df) for split_df in (tmp_train, tmp_valid, df_test))

In [None]:
# The number of classes is taken as the number of rows in metadata.csv
# (presumably one row per class — confirm against the file).
df_metadata = pd.read_csv(PATH_METADATA, encoding='ISO-8859-1')
n_classes = len(df_metadata)
print("Number of classes: ", n_classes)

In [None]:
class DataRetriever(torch_data.Dataset):
    """Dataset that reads images from disk with OpenCV.

    Each item is a dict ``{'X': image, 'y': label}`` where ``image`` is the
    resized RGB image after the optional ``transforms`` and ``preprocess``
    steps, and ``label`` is a ``torch.long`` scalar.

    Args:
        paths: sequence of image file paths.
        targets: sequence of class ids, index-aligned with ``paths``.
        image_size: size passed to ``cv2.resize`` (OpenCV reads it as
            ``(width, height)``).
        transforms: optional callable invoked as ``transforms(image=img)``
            and returning a mapping with an ``'image'`` entry.
        preprocess: optional callable applied to the image last.
    """

    def __init__(
        self, 
        paths, 
        targets, 
        image_size,
        transforms=None,
        preprocess=None,
    ):
        self.paths = paths
        self.targets = targets
        self.image_size = image_size
        self.transforms = transforms
        self.preprocess = preprocess
          
    def __len__(self):
        """Number of samples, i.e. the number of targets."""
        return len(self.targets)
    
    def __getitem__(self, index):
        """Load, resize and transform the sample at ``index``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        path = self.paths[index]
        img = cv2.imread(path)
        # cv2.imread signals a missing/corrupt file by returning None instead
        # of raising; fail here with the offending path rather than with an
        # opaque error inside cv2.resize.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.resize(img, self.image_size)
        # OpenCV loads BGR; the rest of the pipeline works on RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transforms:
            img = self.transforms(image=img)['image']
        if self.preprocess:
            img = self.preprocess(img)
        
        # Shift the stored id down by one (presumably class_id is 1-based in
        # the CSV files — confirm) so labels start at 0.
        y = torch.tensor(self.targets[index] - 1, dtype=torch.long)
            
        return {'X': img, 'y': y}

In [None]:
def get_train_transforms():
    """Build the augmentation pipeline used for training images only.

    Each augmentation is applied with probability 0.5: a rotation with
    ``limit=30`` that fills borders by replication, and a horizontal flip.
    """
    augmentations = [
        A.Rotate(limit=30, border_mode=cv2.BORDER_REPLICATE, p=0.5),
        A.HorizontalFlip(p=0.5),
    ]
    return A.Compose(augmentations, p=1.0)

def get_preprocess():
    """Build the tensor conversion + normalisation applied to every split.

    The same mean/std constants are undone by ``denormalize_image`` for display.
    """
    to_tensor = torchvision.transforms.ToTensor()
    normalize = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return torchvision.transforms.Compose([to_tensor, normalize])

In [None]:
def _build_retriever(paths, targets, transforms=None):
    """Create a DataRetriever with the shared image size and preprocessing."""
    return DataRetriever(
        paths,
        targets,
        image_size=config["image_size"],
        transforms=transforms,
        preprocess=get_preprocess(),
    )


# Only the training split is augmented; valid/test images are just preprocessed.
train_data_retriever = _build_retriever(
    train_paths, train_targets, transforms=get_train_transforms()
)
valid_data_retriever = _build_retriever(valid_paths, valid_targets)
test_data_retriever = _build_retriever(test_paths, test_targets)

In [None]:
def _build_loader(dataset, split, shuffle=False):
    """Wrap ``dataset`` in a DataLoader using the ``<split>_batch_size`` config entry."""
    return torch_data.DataLoader(
        dataset,
        batch_size=config[f"{split}_batch_size"],
        shuffle=shuffle,
    )


# Shuffle only the training data; evaluation order stays deterministic.
train_loader = _build_loader(train_data_retriever, "train", shuffle=True)
valid_loader = _build_loader(valid_data_retriever, "valid")
test_loader = _build_loader(test_data_retriever, "test")

In [None]:
def denormalize_image(image):
    """Undo the per-channel normalisation so the image can be displayed.

    Args:
        image: array whose last axis holds the 3 colour channels.

    Returns:
        The rescaled array, clipped to the [0, 1] range.
    """
    channel_std = [0.229, 0.224, 0.225]
    channel_mean = [0.485, 0.456, 0.406]
    restored = image * channel_std + channel_mean
    # Values pushed outside [0, 1] are clamped to the valid range
    return restored.clip(0, 1)

# Preview the first four training batches, one grid row per batch
fig = plt.figure(figsize=(16, 16))
for i_batch, batch in zip(range(4), train_loader):
    images = batch["X"]
    labels = batch["y"]
    for i, (image, label) in enumerate(zip(images, labels)):
        # Row i_batch, column i of the 4x4 grid
        plt.subplot(4, 4, 4 * i_batch + i + 1)
        # CHW tensor -> HWC array, then undo the normalisation for display
        plt.imshow(denormalize_image(image.permute(1, 2, 0).numpy()))
        plt.title(label.numpy())
        plt.axis("off")


In [None]:
# Let's visualize the first 16 batches of the validation data
# (only the first image of each batch is drawn)
fig = plt.figure(figsize=(16, 16))
for i_batch, batch in enumerate(valid_loader):
    images, labels = batch["X"], batch["y"]
    plt.subplot(4, 4, i_batch + 1)
    # CHW tensor -> HWC array, then undo the normalisation for display
    plt.imshow(denormalize_image(images[0].permute(1, 2, 0).numpy()))
    plt.title(labels[0].numpy())
    plt.axis("off")
    # Stop once the 4x4 grid is full
    if i_batch >= 15:
        break
        

In [None]:
def init_model_mobilenet_v2(n_classes):
    """Load a pretrained MobileNetV2 from torch.hub and swap in a new head.

    The model's ``classifier`` attribute is replaced by a single linear
    layer with 1280 inputs and ``n_classes`` outputs.
    """
    backbone = torch.hub.load("pytorch/vision:v0.6.0", "mobilenet_v2", pretrained=True)
    head = torch_nn.Linear(in_features=1280, out_features=n_classes, bias=True)
    backbone.classifier = head
    return backbone

def init_model_resnet18(n_classes):
    """Load a pretrained ResNet-18 from torch.hub with an ``n_classes``-way head.

    torchvision's ResNet keeps its final layer in ``fc``. Assigning to
    ``classifier`` (as this function did before) only attached an unused
    module, so the network kept its original 1000-way output.
    """
    net = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
    net.fc = torch_nn.Linear(
        # Read the width from the layer being replaced (512 for ResNet-18)
        in_features=net.fc.in_features,
        out_features=n_classes,
        bias=True,
    )
    return net

def init_model_resnet101(n_classes):
    """Load a pretrained ResNet-101 from torch.hub with an ``n_classes``-way head.

    torchvision's ResNet keeps its final layer in ``fc``. Assigning to
    ``classifier`` (as this function did before) only attached an unused
    module, so the network kept its original 1000-way output.
    """
    net = torch.hub.load('pytorch/vision:v0.6.0', 'resnet101', pretrained=True)
    net.fc = torch_nn.Linear(
        # Read the width from the layer being replaced (2048 for ResNet-101)
        in_features=net.fc.in_features,
        out_features=n_classes,
        bias=True,
    )
    return net

def init_model_vgg16(n_classes):
    """Load a pretrained VGG-16 from torch.hub and replace its last classifier layer.

    Only element 6 of ``classifier`` is swapped for a linear layer with 4096
    inputs and ``n_classes`` outputs; the other classifier entries are kept.
    """
    backbone = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16', pretrained=True)
    new_head = torch_nn.Linear(in_features=4096, out_features=n_classes, bias=True)
    backbone.classifier[6] = new_head
    return backbone

def init_model_resnext50_32x4d(n_classes):
    """Load a pretrained ResNeXt-50 (32x4d) from torch.hub with an ``n_classes``-way head.

    ResNeXt is built on torchvision's ResNet class, whose final layer is
    ``fc``. Assigning to ``classifier`` (as this function did before) only
    attached an unused module, so the network kept its 1000-way output.
    """
    net = torch.hub.load('pytorch/vision:v0.6.0', 'resnext50_32x4d', pretrained=True)
    net.fc = torch_nn.Linear(
        # Read the width from the layer being replaced (2048 for ResNeXt-50)
        in_features=net.fc.in_features,
        out_features=n_classes,
        bias=True,
    )
    return net

In [None]:
class LossMeter:
    """Running mean of scalar loss values, one value per ``update`` call."""

    def __init__(self):
        self.avg = 0
        self.n = 0

    def update(self, val):
        """Fold ``val`` into the running mean."""
        self.n += 1
        # Weight of the previous mean once the new sample is counted
        keep_ratio = (self.n - 1) / self.n
        self.avg = val / self.n + keep_ratio * self.avg

        
class AccMeter:
    """Running accuracy over batches of (true labels, per-class scores)."""

    def __init__(self):
        self.avg = 0
        self.n = 0

    def update(self, y_true, y_pred):
        """Fold one batch into the running accuracy.

        Args:
            y_true: tensor of integer class labels, one per sample.
            y_pred: tensor of per-class scores; the argmax over axis 1 is
                taken as the predicted class.
        """
        truth = y_true.cpu().numpy().astype(int)
        guesses = y_pred.cpu().numpy().argmax(axis=1).astype(int)
        seen_before = self.n
        self.n = seen_before + len(truth)
        hits = np.sum(truth == guesses)
        # New mean = this batch's hits plus the old mean re-weighted by its share
        self.avg = hits / self.n + seen_before / self.n * self.avg

In [None]:
class Trainer:
    """Training loop with per-epoch validation, checkpointing and early stopping.

    Args:
        model: network to train; called as ``model(images)``.
        device: device that each batch is moved to.
        optimizer: optimizer stepped once per training batch.
        criterion: callable ``criterion(outputs, targets)`` returning the loss.
        loss_meter: class (not instance) instantiated once per epoch; must
            provide ``update(value)`` and ``avg``.
        score_meter: class (not instance) instantiated once per epoch; must
            provide ``update(targets, outputs)`` and ``avg``.
    """

    def __init__(
        self, 
        model, 
        device, 
        optimizer, 
        criterion, 
        loss_meter, 
        score_meter
    ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.loss_meter = loss_meter
        self.score_meter = score_meter
        
        # Best validation score seen so far and epochs elapsed since it improved
        self.best_valid_score = -np.inf
        self.n_patience = 0
        
        self.messages = {
            "epoch": "[Epoch {}: {}] loss: {:.5f}, score: {:.5f}, time: {} s",
            "checkpoint": "The score improved from {:.5f} to {:.5f}. Save model to '{}'",
            "patience": "\nValid score didn't improve last {} epochs."
        }
    
    def fit(self, epochs, train_loader, valid_loader, save_path, patience):
        """Train for up to ``epochs`` epochs, stopping early on stagnation.

        A checkpoint is written to ``save_path`` whenever the validation
        score improves; training stops once it has not improved for
        ``patience`` consecutive epochs.

        Returns:
            dict with the per-epoch lists ``train_loss``, ``train_score``,
            ``valid_loss`` and ``valid_score``.
        """
        history = {
            "train_loss": [],
            "train_score": [],
            "valid_loss": [],
            "valid_score": [],
        }
        
        for n_epoch in range(1, epochs + 1):
            self.info_message("EPOCH: {}", n_epoch)
            
            train_loss, train_score, train_time = self.train_epoch(train_loader)
            valid_loss, valid_score, valid_time = self.valid_epoch(valid_loader)
            
            history["train_loss"].append(train_loss)
            history["train_score"].append(train_score)
            history["valid_loss"].append(valid_loss)
            history["valid_score"].append(valid_score)
            
            # The template is "[Epoch {}: {}] ...": epoch number first, then
            # the phase name. The arguments used to be swapped, which printed
            # "[Epoch Train: 1]".
            self.info_message(
                self.messages["epoch"], n_epoch, "Train", train_loss, train_score, train_time
            )
            
            self.info_message(
                self.messages["epoch"], n_epoch, "Valid", valid_loss, valid_score, valid_time
            )
            
            if self.best_valid_score < valid_score:
                self.info_message(
                    self.messages["checkpoint"], self.best_valid_score, valid_score, save_path
                )
                self.best_valid_score = valid_score
                self.save_model(n_epoch, save_path)
                self.n_patience = 0
            else:
                self.n_patience += 1
            
            if self.n_patience >= patience:
                self.info_message(self.messages["patience"], patience)
                break
        
        return history
            
    def train_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Batches are dicts with an ``"X"`` (inputs) and a ``"y"`` (targets) entry.

        Returns:
            (mean loss, mean score, elapsed whole seconds).
        """
        self.model.train()
        t = time.time()
        train_loss = self.loss_meter()
        train_score = self.score_meter()
        
        for step, batch in enumerate(train_loader, 1):
            images = batch["X"].to(self.device)
            targets = batch["y"].to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(images)

            loss = self.criterion(outputs, targets)
            loss.backward()

            # Meters receive detached values so they never hold the graph alive
            train_loss.update(loss.detach().item())
            train_score.update(targets, outputs.detach())

            self.optimizer.step()
        
        return train_loss.avg, train_score.avg, int(time.time() - t)
    
    def valid_epoch(self, valid_loader):
        """Evaluate on ``valid_loader`` without computing gradients.

        Returns:
            (mean loss, mean score, elapsed whole seconds).
        """
        self.model.eval()
        t = time.time()
        valid_loss = self.loss_meter()
        valid_score = self.score_meter()

        with torch.no_grad():
            for step, batch in enumerate(valid_loader, 1):
                images = batch["X"].to(self.device)
                targets = batch["y"].to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, targets)

                valid_loss.update(loss.detach().item())
                valid_score.update(targets, outputs)
        
        return valid_loss.avg, valid_score.avg, int(time.time() - t)
    
    def save_model(self, n_epoch, save_path):
        """Write model/optimizer state, best score and epoch number to ``save_path``."""
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            save_path,
        )
    
    @staticmethod
    def info_message(message, *args, end="\n"):
        """Print ``message`` formatted with ``args``."""
        print(message.format(*args), end=end)


In [None]:
# Pick the best available accelerator. The previous expression chose "mps"
# based on torch.cuda.is_available(), so Apple-silicon machines always ran on
# the CPU and CUDA machines were handed a device they do not have.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Map each supported config["model"] value to its builder
model_builders = {
    "mobilenet_v2": init_model_mobilenet_v2,
    "resnet18": init_model_resnet18,
    "resnet101": init_model_resnet101,
    "vgg16": init_model_vgg16,
    "resnext50_32x4d": init_model_resnext50_32x4d,
}
# Fail with a clear message instead of a NameError further down
if config["model"] not in model_builders:
    raise ValueError(f"Unsupported model: {config['model']!r}")
model = model_builders[config["model"]](n_classes)
    
model.to(device)

if config["optimizer"] == "adam":
    optimizer = torch.optim.Adam(model.parameters(), lr=config["adam_lr"])
else:
    raise ValueError(f"Unsupported optimizer: {config['optimizer']!r}")

if config["criterion"] == "cross_entropy":
    criterion = torch_functional.cross_entropy
else:
    raise ValueError(f"Unsupported criterion: {config['criterion']!r}")

# The meter *classes* are passed; Trainer instantiates fresh ones every epoch
trainer = Trainer(
    model, 
    device, 
    optimizer, 
    criterion, 
    LossMeter, 
    AccMeter
)

history = trainer.fit(
    config["max_epochs"], 
    train_loader, 
    valid_loader, 
    config["model_save_path"], 
    config["patience_stop"],
)

In [None]:
# Plot the training history: loss on the left panel, accuracy on the right
plt.figure(figsize=(16, 6))
panels = [
    ("Loss value", [("train_loss", "train loss"), ("valid_loss", "valid loss")]),
    ("Accuracy score", [("train_score", "train acc"), ("valid_score", "valid acc")]),
]
for position, (y_label, curves) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    for history_key, legend_name in curves:
        plt.plot(history[history_key], label=legend_name)
    plt.xticks(fontsize=14)
    plt.xlabel("Epoch number", fontsize=15)
    plt.yticks(fontsize=14)
    plt.ylabel(y_label, fontsize=15)
    plt.legend(fontsize=15)
    plt.grid()

In [None]:
# Load the best model.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on an accelerator can be restored on a machine without it.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; this
# checkpoint also stores "best_valid_score", which may be a NumPy scalar —
# confirm whether weights_only=False is needed on the installed version.
checkpoint = torch.load(config["model_save_path"], map_location=device)

model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
best_valid_score = checkpoint["best_valid_score"]
n_epoch = checkpoint["n_epoch"]

# Put the model in evaluation mode for the inference cells that follow
model.eval()

print(f"Best model valid score: {best_valid_score} ({n_epoch} epoch)")

In [None]:
# Save the model predictions and true labels
y_pred = []
y_test = []
model.eval()
# Pure inference: skip autograd bookkeeping so no graph is built per batch
with torch.no_grad():
    for batch in test_loader:
        logits = model(batch['X'].to(device))
        y_pred.extend(logits.argmax(axis=-1).cpu().numpy())
        # Store plain NumPy integers rather than 0-d tensors
        y_test.extend(batch['y'].cpu().numpy())

# Calculate needed metrics
test_accuracy = sk_metrics.accuracy_score(y_test, y_pred)
test_f1_macro = sk_metrics.f1_score(y_test, y_pred, average="macro")

print(f"Accuracy score on test data:\t{test_accuracy}")
print(f"Macro F1 score on test data:\t{test_f1_macro}")

In [None]:
# Load metadata to get classes people-friendly names
# (presumably row i of the metadata describes class index i — confirm)
labels = df_metadata['minifigure_name'].tolist()

# Calculate confusion matrix.
# Passing the full list of class indices keeps the matrix len(labels) x
# len(labels) even when a class never occurs in y_test or y_pred; otherwise
# the DataFrame construction below fails on a shape mismatch.
class_indices = list(range(len(labels)))
confusion_matrix = sk_metrics.confusion_matrix(y_test, y_pred, labels=class_indices)
# confusion_matrix = confusion_matrix / confusion_matrix.sum(axis=1)
df_confusion_matrix = pd.DataFrame(confusion_matrix, index=labels, columns=labels)

# Show confusion matrix
fig = plt.figure(figsize=(12, 12))
sn.heatmap(df_confusion_matrix, annot=True, cbar=False, cmap='Oranges', linewidths=1, linecolor='black')
plt.xlabel('Predicted labels', fontsize=15)
plt.xticks(fontsize=12)
plt.ylabel('True labels', fontsize=15)
plt.yticks(fontsize=12);

In [None]:
# Collect every misclassified test sample with its true label, predicted
# label and the probability assigned to the prediction.
error_images = []
error_label = []
error_pred = []
error_prob = []
model.eval()
with torch.no_grad():
    for batch in test_loader:
        _X_test, _y_test = batch['X'], batch['y']
        pred = torch.softmax(model(_X_test.to(device)), axis=-1).cpu().numpy()
        pred_class = pred.argmax(axis=-1)
        # Compare per sample so this works for any batch size; the previous
        # `if pred_class != ...` was only valid with exactly one image per batch.
        wrong_positions = np.where(pred_class != _y_test.cpu().numpy())[0]
        for pos in wrong_positions:
            error_images.append(_X_test[pos])
            error_label.append(int(_y_test[pos]))
            error_pred.append(int(pred_class[pos]))
            error_prob.append(float(pred[pos].max()))

In [None]:
# Draw every misclassified image on a roughly square grid
fig = plt.figure(figsize=(16, 16))
n_errors = len(error_images)
if n_errors:
    # floor(sqrt(n)) columns, and as many rows as needed to fit all images
    n_cols = int(n_errors ** 0.5)
    n_rows = math.ceil(n_errors / n_cols)
for ind, image in enumerate(error_images):
    plt.subplot(n_rows, n_cols, ind + 1)
    plt.imshow(denormalize_image(image.permute(1, 2, 0).numpy()))
    plt.title(f"Predict: {labels[error_pred[ind]]} ({error_prob[ind]:.2f}) Real: {labels[error_label[ind]]}")
    plt.axis("off")

In [None]:
# Indices of the test samples the model got wrong
misclassified = np.where(np.array(y_pred) != np.array(y_test))[0]
# Build the confusion matrix once, over the full class list, so that row i
# always corresponds to labels[i] even if some class is absent from the data.
_cm = sk_metrics.confusion_matrix(y_test, y_pred, labels=list(range(len(labels))))
# Row sum minus the diagonal = samples of that true class predicted as another class
misclassified_counts = _cm.sum(axis=1) - np.diag(_cm)
# argsort is ascending, so the last five entries are the worst classes (worst last)
most_confused_classes = np.argsort(misclassified_counts)[-5:]
print("Most confused classes:", [labels[i] for i in most_confused_classes])

In [None]:
from torchvision.transforms import ToPILImage


# Dataset behind the test loader; plot_images below indexes it directly by sample position
test_dataset = test_loader.dataset 

def get_predictions_and_confidence(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect predictions.

    Args:
        model: network called as ``model(inputs)``; switched to eval mode here.
        dataloader: iterable of dicts with an ``"X"`` and a ``"y"`` entry.
        device: device the inputs are moved to.

    Returns:
        Three NumPy arrays in iteration order: predicted class indices,
        true labels, and the softmax probability of each predicted class.
    """
    model.eval()
    predicted, actual, confidences = [], [], []

    with torch.no_grad():
        for sample in dataloader:
            inputs = sample["X"].to(device)
            truth = sample["y"].cpu().numpy()

            logits = model(inputs)
            class_probs = torch.nn.functional.softmax(logits, dim=-1)

            # Class from the raw logits, confidence from the softmax maximum
            predicted += list(logits.argmax(dim=-1).cpu().numpy())
            actual += list(truth)
            confidences += list(class_probs.max(dim=-1).values.cpu().numpy())

    return np.array(predicted), np.array(actual), np.array(confidences)

# Run the model over the test set once: predictions, true labels, confidences
y_pred, y_test, confidence_scores = get_predictions_and_confidence(model, test_loader, device)

# Split the sample positions by whether the prediction matched the label
is_correct = y_pred == y_test
correct_indices = np.where(is_correct)[0]
incorrect_indices = np.where(~is_correct)[0]

# Function to plot images
def plot_images(indices, title, num_images=10):
    """Show up to ``num_images`` test images with predicted and true labels.

    Args:
        indices: positions into ``test_dataset`` / ``y_test`` / ``y_pred``.
        title: figure title.
        num_images: maximum number of images drawn (5 per row).
    """
    shown = indices[:num_images]
    n_cols = 5
    # Keep the original 2x5 layout for up to 10 images and add rows beyond
    # that; the fixed 2x5 grid used to raise for num_images > 10.
    n_rows = max(2, math.ceil(len(shown) / n_cols))
    plt.figure(figsize=(15, 3 * n_rows))
    
    for i, idx in enumerate(shown):
        img = test_dataset[idx]["X"]  # Fetch image
        true_label = labels[y_test[idx]]  # Get true label
        pred_label = labels[y_pred[idx]]  # Get predicted label
        confidence = confidence_scores[idx]  # Get confidence score

        plt.subplot(n_rows, n_cols, i + 1)
        # The dataset returns normalised CHW tensors; undo the normalisation
        # before drawing. Passing them straight to ToPILImage (as before)
        # rendered wrong colours because the values lie outside [0, 1].
        plt.imshow(denormalize_image(img.permute(1, 2, 0).numpy()))
        plt.title(f"Pred: {pred_label} ({confidence:.2f})\nTrue: {true_label}", fontsize=10, color="green" if pred_label == true_label else "red")
        plt.axis("off")

    plt.suptitle(title, fontsize=14, fontweight="bold")
    plt.show()

# Plot correctly classified images
plot_images(correct_indices, "Correctly Classified Images")

# Plot misclassified images
plot_images(incorrect_indices, "Misclassified Images")


In [None]:
# Use Apple's Metal backend when it is available, otherwise the CPU
device_name = "cpu"
if torch.backends.mps.is_available():
    device_name = "mps"
device = torch.device(device_name)
print(device)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.utils as vutils
import os
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Set device: use MPS when the backend is actually available and fall back to
# the CPU otherwise (an unconditional "mps" fails on machines without it).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Define Generator
class Generator(nn.Module):
    """DCGAN generator: maps ``(N, nz, 1, 1)`` latent vectors to ``(N, nc, 64, 64)`` images.

    Args:
        nz: number of latent channels.
        ngf: base number of generator feature maps.
        nc: number of output image channels.

    The first argument of ``nn.LeakyReLU`` is ``negative_slope``.
    ``nn.LeakyReLU(True)`` therefore set the slope to 1.0 and turned every
    activation into the identity; the slope is now passed explicitly (0.2,
    matching the discriminator) together with ``inplace=True``.
    """

    def __init__(self, nz, ngf, nc):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # (nz, 1, 1) -> (ngf*8, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (nc, 64, 64); tanh bounds the output to [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Generate images from the latent batch ``input``."""
        return self.main(input)

# Define Discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator: scores ``(N, nc, 64, 64)`` images with a probability in (0, 1).

    Args:
        nc: number of input image channels.
        ndf: base number of discriminator feature maps.
    """

    def __init__(self, nc, ndf):
        super().__init__()
        # Stem without batch norm
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three stride-2 stages, each doubling the channel count
        for mult in (1, 2, 4):
            layers += [
                nn.Conv2d(ndf * mult, ndf * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse to a single value per image and squash it to (0, 1)
        layers += [
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a flat tensor with one score per element of the output map."""
        scores = self.main(input)
        return scores.view(-1)

# Hyperparameters
nz = 128  # Changed from 256 to 128 for stability
ngf = 64  # base number of generator feature maps
ndf = 64  # base number of discriminator feature maps
nc = 3  # image channels
batch_size = 64  # Ensured consistent batch size

# Initialize models
netG = Generator(nz, ngf, nc).to(device)
netD = Discriminator(nc, ndf).to(device)

# Loss and Optimizers
# NOTE: this rebinds `criterion`, which the classifier cells above set to cross-entropy.
criterion = nn.BCELoss()
# The discriminator uses half the generator's learning rate
optimizerD = optim.Adam(netD.parameters(), lr=0.00005, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.999))

# Dataset Setup (Using your dataset structure)
# NOTE(review): same hard-coded absolute path as BASE_DIR in the classifier
# cells; consider reusing that variable.
dataset_root = '/Users/tejakolla/Documents/sem-2/Deep_learning/project-1-teja2002/archive'
# Resize + centre-crop to the 64x64 resolution used by the GAN models
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize images to [-1, 1]
])

# Loading the dataset with subfolders as different categories
dataset = datasets.ImageFolder(root=dataset_root, transform=transform)

# DataLoader
# NOTE: this rebinds `train_loader`, which the classifier cells above also use.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training DCGAN
num_epochs = 1000
real_label = 1.
fake_label = 0.
# Fixed latent batch so the periodic sample grids are comparable across epochs
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

for epoch in range(num_epochs):
    for i, data in enumerate(train_loader, 0):
        # ---- Train Discriminator ----
        netD.zero_grad()
        real_images = data[0].to(device)
        # Loop-local names: the previous version overwrote the global
        # `batch_size` hyperparameter and the `labels` list of class names
        # that the classifier cells index into.
        cur_batch = real_images.size(0)
        target = torch.full((cur_batch,), real_label - 0.1, dtype=torch.float, device=device)  # Added label smoothing
        output = netD(real_images)
        lossD_real = criterion(output, target)
        lossD_real.backward()

        noise = torch.randn(cur_batch, nz, 1, 1, device=device)
        fake_images = netG(noise)
        target.fill_(fake_label + 0.1)  # Added label smoothing for fake labels
        # detach(): the discriminator update must not back-propagate into the generator
        output = netD(fake_images.detach())
        lossD_fake = criterion(output, target)
        lossD_fake.backward()
        optimizerD.step()

        # ---- Train Generator ----
        netG.zero_grad()
        # The generator is rewarded when the discriminator labels fakes as real
        target.fill_(real_label)
        output = netD(fake_images)
        lossG = criterion(output, target)
        lossG.backward()
        optimizerG.step()

    # Generate sample images every 10 epochs
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            sample_images = netG(fixed_noise).detach().cpu()
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title(f"Generated Images - Epoch {epoch+1}")
        plt.imshow(np.transpose(vutils.make_grid(sample_images, padding=2, normalize=True), (1, 2, 0)))
        plt.show()

print("Training Complete!")

# Generate Synthetic LEGO Minifigure Images
def generate_synthetic_images(num_images):
    """Sample ``num_images`` latent vectors, run them through ``netG`` and show the grid."""
    latent = torch.randn(num_images, nz, 1, 1, device=device)
    with torch.no_grad():
        samples = netG(latent).detach().cpu()
    # normalize=True rescales the grid for display
    grid = vutils.make_grid(samples, padding=2, normalize=True)
    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.title("Synthetic LEGO Minifigures")
    # make_grid returns CHW; imshow expects HWC
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    plt.show()

# Generate and display 16 synthetic images
generate_synthetic_images(16)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.utils as vutils
import os
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Set device: prefer CUDA, then Apple MPS, and fall back to the CPU. The
# previous expression fell back to "mps" unconditionally, which fails on
# machines that have neither accelerator.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Define Generator
class Generator(nn.Module):
    """Generator for the WGAN-GP cell: ``(N, nz, 1, 1)`` latents to ``(N, nc, 64, 64)`` images.

    Args:
        nz: number of latent channels.
        ngf: base number of generator feature maps.
        nc: number of output image channels.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()
        widths = [ngf * 8, ngf * 4, ngf * 2, ngf]
        # Project the latent vector to a 4x4 feature map
        layers = [
            nn.ConvTranspose2d(nz, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Each stage doubles the spatial size and halves the channel count
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Final upsampling to nc channels; tanh bounds the output to [-1, 1]
        layers += [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Generate images from the latent batch ``input``."""
        return self.main(input)

# Define Discriminator (Critic for WGAN)
class Discriminator(nn.Module):
    """WGAN critic: maps ``(N, nc, 64, 64)`` images to unbounded scores (no sigmoid).

    Args:
        nc: number of input image channels.
        ndf: base number of critic feature maps.
    """

    def __init__(self, nc, ndf):
        super().__init__()
        # Stem without batch norm
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three stride-2 stages, each doubling the channel count
        for mult in (1, 2, 4):
            layers += [
                nn.Conv2d(ndf * mult, ndf * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Raw score head: the output is left unbounded
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a flat tensor with one score per element of the output map."""
        scores = self.main(input)
        return scores.view(-1)

# Wasserstein Loss
def wasserstein_loss(output, target):
    """Mean of the element-wise product of critic scores and +/-1 targets.

    With ``target`` = +1 this is the mean score; with -1 it is the negated mean.
    """
    signed_scores = output * target
    return signed_scores.mean()

# Gradient Penalty Calculation
def compute_gradient_penalty(D, real_samples, fake_samples):
    """WGAN-GP penalty: mean squared deviation of the critic's gradient norm from 1.

    The gradient is evaluated at random per-sample interpolations between
    ``real_samples`` and ``fake_samples``.

    Args:
        D: critic; called on a batch and returning a tensor of scores.
        real_samples: 4-D batch ``(N, C, H, W)`` of real images.
        fake_samples: batch of generated images with the same shape.

    Returns:
        Scalar tensor, differentiable w.r.t. the parameters of ``D``.
    """
    # The interpolation points are only evaluation locations; detaching keeps
    # the generator's graph out of the critic update.
    real_samples = real_samples.detach()
    fake_samples = fake_samples.detach()

    batch_size = real_samples.size(0)
    # One mixing coefficient per sample, created on the inputs' own device so
    # the function no longer depends on a global `device`
    epsilon = torch.rand(
        batch_size, 1, 1, 1, device=real_samples.device, dtype=real_samples.dtype
    )
    interpolated = epsilon * real_samples + (1 - epsilon) * fake_samples
    interpolated.requires_grad_(True)

    prob_interpolated = D(interpolated)
    # create_graph=True so the penalty itself can be back-propagated through
    gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                                    grad_outputs=torch.ones_like(prob_interpolated),
                                    create_graph=True, retain_graph=True)[0]
    # reshape (not view): the returned gradient is not guaranteed to be contiguous
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Hyperparameters
nz = 128  # Latent vector size
ngf = 64  # base number of generator feature maps
ndf = 64  # base number of critic feature maps
nc = 3  # Number of channels in the images
batch_size = 64
num_epochs = 2000

# Initialize models
# NOTE: these rebind netG/netD, replacing the DCGAN models from the previous cell.
netG = Generator(nz, ngf, nc).to(device)
netD = Discriminator(nc, ndf).to(device)

# Optimizers
# The critic's learning rate is 2.5x the generator's
optimizerD = optim.Adam(netD.parameters(), lr=0.00005, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.00002, betas=(0.5, 0.999))

# DataLoader Setup
# NOTE(review): same hard-coded absolute path as in the earlier cells.
dataset_root = '/Users/tejakolla/Documents/sem-2/Deep_learning/project-1-teja2002/archive'
# Resize + centre-crop to 64x64, then scale each channel with mean 0.5 / std 0.5
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.ImageFolder(root=dataset_root, transform=transform)
# NOTE: this rebinds `train_loader` once more.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training Loop for WGAN-GP
noise_std = 0.1  # Small Gaussian noise added to the critic's inputs to stabilize training
n_critic = 5     # The generator is updated once every n_critic critic updates
gp_weight = 10   # Weight of the gradient penalty in the critic loss
# Defined here so the cell does not rely on `fixed_noise` left over from the DCGAN cell
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

for epoch in range(num_epochs):
    for i, data in enumerate(train_loader, 0):
        real_images = data[0].to(device)
        # Loop-local name so the global `batch_size` hyperparameter is not overwritten
        cur_batch = real_images.size(0)

        real_images = real_images + torch.randn_like(real_images) * noise_std

        # Train Critic (Discriminator)
        netD.zero_grad()
        # 1-D targets matching the critic's flattened (N,) output; the old
        # (N, 1) shape broadcast the product to an (N, N) matrix.
        real_labels = torch.ones(cur_batch, device=device)
        fake_labels = -torch.ones(cur_batch, device=device)
        
        # Real images
        output = netD(real_images)
        lossD_real = wasserstein_loss(output, real_labels)
        
        # Fake images. The noise is added *after* generation: the previous code
        # perturbed `fake_images` before it was assigned, which raises a
        # NameError on a fresh kernel and otherwise touched a stale tensor.
        noise = torch.randn(cur_batch, nz, 1, 1, device=device)
        fake_images = netG(noise)
        # Detach so the critic update never back-propagates into the generator
        fake_for_critic = fake_images.detach()
        fake_for_critic = fake_for_critic + torch.randn_like(fake_for_critic) * noise_std
        output = netD(fake_for_critic)
        lossD_fake = wasserstein_loss(output, fake_labels)
        
        # Compute gradient penalty on the same (detached) inputs the critic saw
        gradient_penalty = compute_gradient_penalty(netD, real_images, fake_for_critic)
        
        # Total discriminator loss
        lossD = lossD_real + lossD_fake + gp_weight * gradient_penalty
        # No retain_graph needed: nothing in lossD depends on the generator's graph
        lossD.backward()
        optimizerD.step()
        
        # Train Generator (only once every n_critic steps)
        if i % n_critic == 0:
            netG.zero_grad()
            output = netD(fake_images)  # Use the fresh fake_images
            lossG = wasserstein_loss(output, real_labels)
            lossG.backward()
            optimizerG.step()


    # Print and visualize results periodically
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}] - D Loss: {lossD.item()}, G Loss: {lossG.item()}")
        with torch.no_grad():
            sample_images = netG(fixed_noise).detach().cpu()
        plt.figure(figsize=(8, 8))
        plt.imshow(np.transpose(vutils.make_grid(sample_images, padding=2, normalize=True), (1, 2, 0)))
        plt.title(f"Generated Images - Epoch {epoch+1}")
        plt.axis("off")
        plt.show()

print("Training Complete!")
