In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import os
from sklearn.metrics import confusion_matrix, precision_score, recall_score, accuracy_score
import seaborn as sns
from tqdm import tqdm
import json

In [2]:
torch.manual_seed(42)
np.random.seed(42)

In [3]:
class ShapeDataset(Dataset):
    """Custom dataset for generating shapes with labels"""

    def __init__(self, num_samples=10000, img_size=64, transform=None):
        self.num_samples = num_samples
        self.img_size = img_size
        self.transform = transform
        self.classes = ['circle', 'square', 'triangle']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

    def __len__(self):
        return self.num_samples

    def generate_circle(self, img_size):
        """Generate a circle image"""
        img = Image.new('L', (img_size, img_size), 0)
        draw = ImageDraw.Draw(img)
        radius = np.random.randint(img_size//6, img_size//3)
        center_x = np.random.randint(radius, img_size - radius)
        center_y = np.random.randint(radius, img_size - radius)
        draw.ellipse([center_x - radius, center_y - radius,
                     center_x + radius, center_y + radius], fill=255)
        return img

    def generate_square(self, img_size):
        """Generate a square image"""
        img = Image.new('L', (img_size, img_size), 0)
        draw = ImageDraw.Draw(img)
        size = np.random.randint(img_size//4, img_size//2)
        x = np.random.randint(0, img_size - size)
        y = np.random.randint(0, img_size - size)
        draw.rectangle([x, y, x + size, y + size], fill=255)
        return img

    def generate_triangle(self, img_size):
        """Generate a triangle image"""
        img = Image.new('L', (img_size, img_size), 0)
        draw = ImageDraw.Draw(img)
        size = np.random.randint(img_size//4, img_size//2)
        x = np.random.randint(size//2, img_size - size//2)
        y = np.random.randint(size//2, img_size - size//2)
        points = [(x, y - size//2), (x - size//2, y + size//2), (x + size//2, y + size//2)]
        draw.polygon(points, fill=255)
        return img

    def __getitem__(self, idx):
        # Randomly select a class
        class_idx = np.random.randint(0, len(self.classes))
        class_name = self.classes[class_idx]

        # Generate corresponding shape
        if class_name == 'circle':
            img = self.generate_circle(self.img_size)
        elif class_name == 'square':
            img = self.generate_square(self.img_size)
        else:  # triangle
            img = self.generate_triangle(self.img_size)

        if self.transform:
            img = self.transform(img)

        return img, class_idx

In [4]:
class Generator(nn.Module):
    """Conditional Generator for CGAN"""

    def __init__(self, latent_dim=100, num_classes=3, img_channels=1, img_size=64):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_channels = img_channels
        self.img_size = img_size

        # Label embedding
        self.label_embedding = nn.Embedding(num_classes, latent_dim)

        # Generator network
        self.model = nn.Sequential(
            # Input: latent_dim + latent_dim (for label)
            nn.Linear(latent_dim * 2, 256 * 8 * 8),
            nn.BatchNorm1d(256 * 8 * 8),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (256, 8, 8)),

            # Upsampling layers
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, img_channels, 4, 2, 1),  # 64x64
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Embed labels
        label_embed = self.label_embedding(labels)

        # Concatenate noise and label embedding
        gen_input = torch.cat((noise, label_embed), dim=1)

        # Generate image
        img = self.model(gen_input)
        return img


In [5]:
class Discriminator(nn.Module):
    """Conditional Discriminator for CGAN"""

    def __init__(self, num_classes=3, img_channels=1, img_size=64):
        super(Discriminator, self).__init__()
        self.num_classes = num_classes
        self.img_channels = img_channels
        self.img_size = img_size

        # Label embedding
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)

        # Discriminator network
        self.model = nn.Sequential(
            # Input: img_channels + 1 (for embedded label)
            nn.Conv2d(img_channels + 1, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        # Embed labels and reshape to match image dimensions
        label_embed = self.label_embedding(labels)
        label_embed = label_embed.view(label_embed.size(0), 1, self.img_size, self.img_size)

        # Concatenate image and label
        disc_input = torch.cat((img, label_embed), dim=1)

        # Get validity score
        validity = self.model(disc_input)
        return validity

In [6]:
class CGANTrainer:
    """Trainer class for Conditional GAN"""

    def __init__(self, latent_dim=100, num_classes=3, img_channels=1, img_size=64,
                 lr=0.0002, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_size = img_size

        # Initialize models
        self.generator = Generator(latent_dim, num_classes, img_channels, img_size).to(device)
        self.discriminator = Discriminator(num_classes, img_channels, img_size).to(device)

        # Loss function
        self.criterion = nn.BCELoss()

        # Optimizers
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

        # Training metrics
        self.train_history = {'g_loss': [], 'd_loss': [], 'accuracy': []}

    def train(self, dataloader, num_epochs=100, save_interval=10):
        """Train the CGAN"""
        print(f"Training on device: {self.device}")

        for epoch in range(num_epochs):
            epoch_g_loss = 0
            epoch_d_loss = 0
            correct_predictions = 0
            total_predictions = 0

            progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')

            for i, (real_imgs, real_labels) in enumerate(progress_bar):
                batch_size = real_imgs.size(0)
                real_imgs = real_imgs.to(self.device)
                real_labels = real_labels.to(self.device)

                # Create labels for real and fake data
                real_validity = torch.ones(batch_size, 1).to(self.device)
                fake_validity = torch.zeros(batch_size, 1).to(self.device)

                # ---------------------
                # Train Discriminator
                # ---------------------
                self.optimizer_D.zero_grad()

                # Real images
                real_pred = self.discriminator(real_imgs, real_labels)
                d_real_loss = self.criterion(real_pred, real_validity)

                # Fake images
                noise = torch.randn(batch_size, self.latent_dim).to(self.device)
                fake_labels = torch.randint(0, self.num_classes, (batch_size,)).to(self.device)
                fake_imgs = self.generator(noise, fake_labels)
                fake_pred = self.discriminator(fake_imgs.detach(), fake_labels)
                d_fake_loss = self.criterion(fake_pred, fake_validity)

                # Total discriminator loss
                d_loss = (d_real_loss + d_fake_loss) / 2
                d_loss.backward()
                self.optimizer_D.step()

                # -----------------
                # Train Generator
                # -----------------
                self.optimizer_G.zero_grad()

                # Generate fake images and get discriminator's opinion
                fake_pred = self.discriminator(fake_imgs, fake_labels)
                g_loss = self.criterion(fake_pred, real_validity)

                g_loss.backward()
                self.optimizer_G.step()

                # Calculate accuracy (discriminator's ability to classify real vs fake)
                real_acc = (real_pred > 0.5).float().mean()
                fake_acc = (fake_pred <= 0.5).float().mean()
                batch_acc = (real_acc + fake_acc) / 2

                epoch_g_loss += g_loss.item()
                epoch_d_loss += d_loss.item()
                correct_predictions += batch_acc.item() * batch_size
                total_predictions += batch_size

                progress_bar.set_postfix({
                    'D_loss': f'{d_loss.item():.4f}',
                    'G_loss': f'{g_loss.item():.4f}',
                    'Acc': f'{batch_acc.item():.4f}'
                })

            # Calculate epoch metrics
            avg_g_loss = epoch_g_loss / len(dataloader)
            avg_d_loss = epoch_d_loss / len(dataloader)
            accuracy = correct_predictions / total_predictions

            self.train_history['g_loss'].append(avg_g_loss)
            self.train_history['d_loss'].append(avg_d_loss)
            self.train_history['accuracy'].append(accuracy)

            print(f'Epoch [{epoch+1}/{num_epochs}] - G_loss: {avg_g_loss:.4f}, '
                  f'D_loss: {avg_d_loss:.4f}, Accuracy: {accuracy:.4f}')

            # Save sample images
            if (epoch + 1) % save_interval == 0:
                self.save_sample_images(epoch + 1)

        print("Training completed!")

    def save_sample_images(self, epoch, num_samples=9):
        """Save sample generated images"""
        self.generator.eval()

        with torch.no_grad():
            # Generate samples for each class
            fig, axes = plt.subplots(3, 3, figsize=(10, 10))
            fig.suptitle(f'Generated Samples - Epoch {epoch}')

            for i, class_name in enumerate(['circle', 'square', 'triangle']):
                for j in range(3):
                    noise = torch.randn(1, self.latent_dim).to(self.device)
                    label = torch.tensor([i]).to(self.device)

                    fake_img = self.generator(noise, label)
                    fake_img = fake_img.cpu().squeeze().numpy()

                    axes[i, j].imshow(fake_img, cmap='gray')
                    axes[i, j].set_title(f'{class_name.capitalize()}')
                    axes[i, j].axis('off')

            os.makedirs('generated_samples', exist_ok=True)
            plt.tight_layout()
            plt.savefig(f'generated_samples/epoch_{epoch}.png')
            plt.close()

        self.generator.train()

    def evaluate_model(self, test_dataloader):
        """Evaluate the model and calculate metrics"""
        self.discriminator.eval()

        all_predictions = []
        all_labels = []

        with torch.no_grad():
            for real_imgs, real_labels in test_dataloader:
                real_imgs = real_imgs.to(self.device)
                real_labels = real_labels.to(self.device)

                # Get discriminator predictions
                predictions = self.discriminator(real_imgs, real_labels)

                # Convert to binary predictions (real vs fake)
                binary_preds = (predictions > 0.5).cpu().numpy().astype(int)
                real_labels_binary = np.ones(len(real_labels))  # All are real images

                all_predictions.extend(binary_preds.flatten())
                all_labels.extend(real_labels_binary)

        # Calculate metrics
        accuracy = accuracy_score(all_labels, all_predictions)
        precision = precision_score(all_labels, all_predictions, average='weighted', zero_division=0)
        recall = recall_score(all_labels, all_predictions, average='weighted', zero_division=0)

        # Confusion matrix
        cm = confusion_matrix(all_labels, all_predictions)

        print(f"\nModel Evaluation Results:")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")

        # Plot confusion matrix
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.savefig('confusion_matrix.png')
        plt.show()

        return accuracy, precision, recall

    def plot_training_history(self):
        """Plot training history"""
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.plot(self.train_history['g_loss'], label='Generator Loss')
        plt.plot(self.train_history['d_loss'], label='Discriminator Loss')
        plt.title('Training Losses')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 3, 2)
        plt.plot(self.train_history['accuracy'])
        plt.title('Discriminator Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')

        plt.subplot(1, 3, 3)
        # Show final accuracy
        final_acc = self.train_history['accuracy'][-1] if self.train_history['accuracy'] else 0
        plt.bar(['Final Accuracy'], [final_acc])
        plt.title('Final Model Accuracy')
        plt.ylabel('Accuracy')
        plt.ylim(0, 1)

        plt.tight_layout()
        plt.savefig('training_history.png')
        plt.show()

    def save_model(self, path='models'):
        """Save the trained models"""
        os.makedirs(path, exist_ok=True)

        torch.save({
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'optimizer_G_state_dict': self.optimizer_G.state_dict(),
            'optimizer_D_state_dict': self.optimizer_D.state_dict(),
            'train_history': self.train_history
        }, os.path.join(path, 'cgan_checkpoint.pth'))

        # Save just the generator for inference
        torch.save(self.generator.state_dict(), os.path.join(path, 'generator.pth'))

        print(f"Models saved to {path}/")

    def generate_samples_by_class(self, class_name, num_samples=5):
        """Generate samples for a specific class"""
        self.generator.eval()

        class_to_idx = {'circle': 0, 'square': 1, 'triangle': 2}
        if class_name not in class_to_idx:
            print(f"Invalid class name. Choose from: {list(class_to_idx.keys())}")
            return

        class_idx = class_to_idx[class_name]

        with torch.no_grad():
            fig, axes = plt.subplots(1, num_samples, figsize=(num_samples * 2, 2))
            if num_samples == 1:
                axes = [axes]

            for i in range(num_samples):
                noise = torch.randn(1, self.latent_dim).to(self.device)
                label = torch.tensor([class_idx]).to(self.device)

                fake_img = self.generator(noise, label)
                fake_img = fake_img.cpu().squeeze().numpy()

                axes[i].imshow(fake_img, cmap='gray')
                axes[i].set_title(f'{class_name.capitalize()} {i+1}')
                axes[i].axis('off')

            plt.tight_layout()
            plt.savefig(f'generated_{class_name}_samples.png')
            plt.show()

        self.generator.train()

In [7]:
def main():
    """Main training function"""
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Data transforms
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
    ])

    # Create datasets
    train_dataset = ShapeDataset(num_samples=8000, transform=transform)
    test_dataset = ShapeDataset(num_samples=2000, transform=transform)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Initialize trainer
    trainer = CGANTrainer(device=device)

    # Train the model
    print("Starting CGAN training...")
    trainer.train(train_loader, num_epochs=20, save_interval=20)

    # Evaluate the model
    print("\nEvaluating model...")
    accuracy, precision, recall = trainer.evaluate_model(test_loader)

    # Plot training history
    trainer.plot_training_history()

    # Generate samples for each class
    print("\nGenerating sample images...")
    for class_name in ['circle', 'square', 'triangle']:
        trainer.generate_samples_by_class(class_name, num_samples=5)

    # Save the trained model
    trainer.save_model()

    # Save training metrics
    metrics = {
        'final_accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'training_history': trainer.train_history
    }

    with open('training_metrics.json', 'w') as f:
        json.dump(metrics, f, indent=2)

    print(f"\nTraining completed successfully!")
    print(f"Final Model Accuracy: {accuracy:.4f}")

    if accuracy >= 0.7:
        print("Model meets the 70% accuracy requirement!")
    else:
        print("Model accuracy is below 70%. Consider training longer or adjusting hyperparameters.")


In [None]:
if __name__ == "__main__":
    main()

Using device: cpu
Training samples: 8000
Test samples: 2000
Starting CGAN training...
Training on device: cpu


Epoch 1/20: 100%|██████████| 125/125 [06:08<00:00,  2.94s/it, D_loss=0.2498, G_loss=2.5015, Acc=0.9297]


Epoch [1/20] - G_loss: 2.4975, D_loss: 0.3246, Accuracy: 0.9159


Epoch 2/20: 100%|██████████| 125/125 [06:02<00:00,  2.90s/it, D_loss=0.1707, G_loss=3.6550, Acc=0.9766]


Epoch [2/20] - G_loss: 2.6607, D_loss: 0.3326, Accuracy: 0.9053


Epoch 3/20: 100%|██████████| 125/125 [06:02<00:00,  2.90s/it, D_loss=0.4709, G_loss=2.2745, Acc=0.6484]


Epoch [3/20] - G_loss: 2.8539, D_loss: 0.2742, Accuracy: 0.9242


Epoch 4/20:  42%|████▏     | 53/125 [02:33<03:26,  2.87s/it, D_loss=0.3669, G_loss=4.3628, Acc=1.0000]