In [None]:
from network import *
import torch.optim as optim
import random
import os
from PIL import Image
from torchvision import transforms, datasets
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

def _list_image_files(folder_path, extensions):
    """Return the sorted names of files in ``folder_path`` ending in one of ``extensions``.

    Matching is case-insensitive. Sorting makes the candidate order (and thus
    ``random.sample`` draws under a fixed seed) independent of ``os.listdir`` order.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    return sorted(f for f in os.listdir(folder_path) if f.lower().endswith(suffixes))


def train(
    folder_path="training_images",
    save_dir="saved_model",
    iterations=50000,
    batch_size=128,
    seed_size=100,
    target_shape=(64, 64),
    lr=0.0002,
    save_every=50,
    extensions=(".jpg", ".jpeg"),
):
    """Train a GAN_Gen / GAN_Dis pair on the images found in ``folder_path``.

    Each iteration performs:
      1. one discriminator update with BCE loss on a random batch of real
         images (target 1) plus detached generator samples (target 0);
      2. one generator update on fresh samples labelled as real (target 1).

    The whole generator module is written with ``torch.save`` to
    ``<save_dir>/model`` every ``save_every`` iterations (including iteration 0)
    and once more after the last iteration.

    Args:
        folder_path: Directory holding the training images.
        save_dir: Directory the generator checkpoint is written to (created if missing).
        iterations: Number of discriminator+generator update rounds.
        batch_size: Images per batch. It is reduced to the number of available
            files if there are fewer.
        seed_size: Length of the Gaussian noise vector fed to the generator.
        target_shape: ``(height, width)`` passed to ``load_image``.
        lr: Adam learning rate for both optimizers.
        save_every: Checkpoint/progress-print interval in iterations.
        extensions: File suffixes (case-insensitive) that count as training images.

    Raises:
        ValueError: If ``batch_size``/``save_every`` are not positive, or no
            matching image files exist in ``folder_path``.
    """
    if batch_size < 1 or save_every < 1:
        raise ValueError("batch_size and save_every must be positive integers")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "model")

    image_files = _list_image_files(folder_path, extensions)
    if not image_files:
        raise ValueError(f"no files ending in {tuple(extensions)} found in {folder_path!r}")
    # random.sample cannot draw more items than exist; shrink the batch rather than crash.
    batch_size = min(batch_size, len(image_files))

    # NOTE(review): assumes GAN_Gen maps (batch, seed_size) noise to images with the
    # same layout load_image produces, and that GAN_Dis ends in a sigmoid (BCELoss
    # needs inputs in [0, 1]) — confirm in network.py. Real images are in [0, 1]
    # (ToTensor); if the generator ends in tanh, consider normalizing to [-1, 1].
    disc = GAN_Dis().to(device)
    gen = GAN_Gen().to(device)
    optim_d = optim.Adam(disc.parameters(), lr=lr)
    optim_g = optim.Adam(gen.parameters(), lr=lr)
    crit = nn.BCELoss()

    for i in range(iterations):
        if i % save_every == 0:
            print("Iteration", i)
            torch.save(gen, save_path)

        # Discriminator step: real -> 1, fake -> 0.
        optim_d.zero_grad()
        selected_files = random.sample(image_files, batch_size)
        # load_image returns a CPU tensor; move it next to the model weights.
        real = load_image(folder_path, target_shape, selected_files).to(device)
        # The fake batch is only a discriminator input here, so skip building a
        # graph through the generator (equivalent to the former .detach()).
        with torch.no_grad():
            fake = gen(torch.randn(batch_size, seed_size, device=device))
        pred_real = disc(real)
        pred_fake = disc(fake)
        # *_like targets match the discriminator output's shape, dtype and device.
        disc_loss = (crit(pred_real, torch.ones_like(pred_real))
                     + crit(pred_fake, torch.zeros_like(pred_fake)))
        disc_loss.backward()
        optim_d.step()

        # Generator step: push the discriminator to label fresh fakes as real.
        # This backward also leaves gradients on disc's parameters; they are
        # cleared by optim_d.zero_grad() at the start of the next iteration.
        optim_g.zero_grad()
        pred_gen = disc(gen(torch.randn(batch_size, seed_size, device=device)))
        gen_loss = crit(pred_gen, torch.ones_like(pred_gen))
        gen_loss.backward()
        optim_g.step()

    # Persist the fully trained generator; the in-loop save happens before the
    # update, so without this the last iterations would be lost.
    torch.save(gen, save_path)

def load_image(folder_path, target_shape, file_names, mode="RGB"):
    """Load the given image files into one batched tensor.

    For each name, ``folder_path/<name>`` is opened with PIL, optionally
    converted to ``mode``, resized to ``target_shape`` and converted with
    ``transforms.ToTensor`` (channels-first; values scaled to [0, 1] for 8-bit
    modes such as RGB). The per-image tensors are stacked along a new leading
    batch dimension.

    Args:
        folder_path: Directory containing the images.
        target_shape: ``(height, width)`` every image is resized to, so files of
            different sizes can be batched together.
        file_names: File names relative to ``folder_path``.
        mode: PIL mode each image is converted to before resizing. The default
            ``"RGB"`` makes grayscale/CMYK/palette files yield 3 channels like
            the rest. Pass ``None`` to keep each file's native mode.

    Returns:
        CPU tensor of shape ``(len(file_names), C, height, width)``.
    """
    transform = transforms.Compose([
        transforms.Resize(target_shape),
        transforms.ToTensor(),
    ])
    tensors = []
    for name in file_names:
        # `with` closes the file handle PIL keeps open after Image.open; the
        # transform (or convert) forces the pixel data to load while it is open.
        with Image.open(os.path.join(folder_path, name)) as image:
            if mode is not None:
                image = image.convert(mode)
            tensors.append(transform(image))
    return torch.stack(tensors, 0)
        
# Start training as soon as this cell/script runs (also runs on import).
train()