In [None]:
# Mount Google Drive at /content/drive (requires a Google Colab runtime, via google.colab);
# every dataset, annotation and checkpoint path in this notebook lives under this mount.
from google.colab import drive
drive.mount('/content/drive')

### Loading Data

Written by us

In [None]:
# Paths of image directories and annotation CSVs. They all live under one project folder
# on Drive, so the base path is defined once instead of being repeated per constant.
RESOURCES_DIR = '/content/drive/MyDrive/Colab Notebooks/EECS 442 final project/Resources'
train_imgs = f'{RESOURCES_DIR}/train'
test_imgs = f'{RESOURCES_DIR}/test'
train_annotations = f'{RESOURCES_DIR}/train_data_5000images.csv'
test_annotations = f'{RESOURCES_DIR}/test_data.csv'

In [None]:
import os
import pandas as pd
from torchvision.io import read_image
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

class HumanPoseDataset(Dataset):
    """
    Map-style dataset of ``(image, label)`` pairs described by an annotations CSV.

    Column 0 of the CSV holds the image file name (joined onto ``img_dir``) and
    column 1 holds the label. The optional ``transform`` / ``target_transform``
    callables are applied to the image and the label respectively.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of rows in the annotations CSV."""
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        """Read image ``idx`` from disk and return ``(image, label)`` after any transforms."""
        file_name = self.img_labels.iloc[idx, 0]
        label = self.img_labels.iloc[idx, 1]
        image = read_image(os.path.join(self.img_dir, file_name))
        image = self.transform(image) if self.transform else image
        label = self.target_transform(label) if self.target_transform else label
        return image, label

In [None]:
from torchvision import datasets, models, transforms

# Preprocessing for the pose images: crop the centre, shrink, randomly mirror, then convert to PIL.
# NOTE(review): because the pipeline ends in ToPILImage, dataset items are PIL images and are not
# scaled to [-1, 1] like the Flowers102 pipeline below; default DataLoader batching likely cannot
# collate PIL images — confirm before training on this dataset.
pose_transform_steps = [
    transforms.CenterCrop(512),
    transforms.Resize(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToPILImage(),
    # transforms.ToTensor(), # Scales data into [0,1]
    # transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1]
]
img_transform = transforms.Compose(pose_transform_steps)

train_pose_dataset = HumanPoseDataset(
    annotations_file=train_annotations,
    img_dir=train_imgs,
    transform=img_transform,
)
test_pose_dataset = HumanPoseDataset(
    annotations_file=test_annotations,
    img_dir=test_imgs,
    transform=img_transform,
)

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt

def show_images(datset, num_samples=20, cols=4):
    """
    Plot the first ``num_samples`` items of ``datset`` in a grid ``cols`` wide.

    Each item is indexed with ``[0]`` and passed to ``plt.imshow``, so items are
    expected to be ``(image, label)`` pairs whose image matplotlib can draw
    (e.g. a PIL image) — TODO confirm for datasets other than Flowers102.

    Args:
        datset: iterable dataset to preview (the argument is now actually used;
            the previous body iterated a global named ``data`` instead).
        num_samples: maximum number of panels to draw.
        cols: number of grid columns.
    """
    plt.figure(figsize=(15,15))
    # Ceiling division: just enough rows for num_samples panels (no spare empty row)
    rows = -(-num_samples // cols)
    for i, sample in enumerate(datset):
        if i == num_samples:
            break
        plt.subplot(rows, cols, i + 1)
        plt.imshow(sample[0])

# Preview a few raw Flowers102 images (no transform applied, downloaded into the working dir)
data = torchvision.datasets.Flowers102(download=True, root=".")
show_images(data)

### Forward Diffusion

Reused

In [None]:
import torch.nn.functional as F

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Return ``timesteps`` noise variances (betas) evenly spaced from ``start`` to ``end`` inclusive."""
    return torch.linspace(start=start, end=end, steps=timesteps)

def get_index_from_list(vals, t, x_shape):
    """
    Gather ``vals[t]`` for every batch element and reshape it for broadcasting.

    ``t`` is moved to the CPU for the gather, so ``vals`` is expected to live on
    the CPU. The result has shape ``(batch, 1, ..., 1)`` with ``len(x_shape) - 1``
    trailing singleton axes and is moved to ``t.device``, so it broadcasts against
    a tensor of shape ``x_shape``.
    """
    picked = torch.gather(vals, -1, t.cpu())
    broadcast_shape = [t.shape[0]] + [1] * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Noise ``x_0`` to timestep ``t`` in closed form.

    Computes ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps`` with
    ``eps ~ N(0, I)``, using the precomputed global schedule tensors.

    Returns:
        ``(x_noisy, eps)``, both moved to ``device``.
    """
    eps = torch.randn_like(x_0)
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape).to(device)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_0.shape).to(device)
    eps = eps.to(device)
    # scaled signal + scaled noise (mean + std * eps)
    x_noisy = signal_scale * x_0.to(device) + noise_scale * eps
    return x_noisy, eps


# Define beta schedule
T = 500
betas = linear_beta_schedule(timesteps=T)

# Closed-form terms, precomputed once for every timestep
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np

# Side length images are resized to, and number of images per training batch
IMG_SIZE, BATCH_SIZE = 64, 128


def load_transformed_dataset():
    """
    Download Flowers102 and return its train and test splits concatenated.

    Every image is resized to ``IMG_SIZE`` x ``IMG_SIZE``, randomly flipped
    horizontally, converted to a tensor and rescaled to [-1, 1].
    """
    data_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # Scales data into [0,1]
        transforms.Lambda(lambda t: (t * 2) - 1),  # Scale between [-1, 1]
    ])
    splits = [
        torchvision.datasets.Flowers102(root=".", download=True, transform=data_transform),
        torchvision.datasets.Flowers102(root=".", download=True, transform=data_transform, split='test'),
    ]
    return torch.utils.data.ConcatDataset(splits)
def show_tensor_image(image):
    """
    Display a [-1, 1]-scaled CHW image tensor with matplotlib.

    If ``image`` is 4-D (a batch), only the first image is shown. The tensor must
    be convertible with ``.numpy()`` (i.e. detached and on the CPU).
    """
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        # Clamp so out-of-range values saturate instead of wrapping around in the uint8 cast below
        transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))

# Combined train + test Flowers102 images, shuffled into fixed-size batches (last partial batch dropped)
data = load_transformed_dataset()
dataloader = DataLoader(
    data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
print(len(data))

### Backward Diffusion Models
V1 is pieced together from many sources.

V2 is written by us.

V3 (V2 with a self-attention step in each block) is experimental and does not work yet.

In [None]:
from torch import nn
import math

# SIMPLE DIFFUSION V1
class Up(nn.Module):
    """
    Decoder stage of the V1 U-Net: conv -> add time embedding -> conv -> 2x transposed-conv.

    The first conv takes ``2 * in_ch`` channels because the caller concatenates the
    skip connection onto the incoming features. ``up`` is accepted but unused.
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
        self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t, ):
        features = self.bnorm1(self.relu(self.conv1(x)))
        # Broadcast the per-sample time vector over the two spatial axes
        features = features + self.relu(self.time_mlp(t))[..., None, None]
        features = self.bnorm2(self.relu(self.conv2(features)))
        return self.transform(features)

class Down(nn.Module):
    """
    Encoder stage of the V1 U-Net: conv -> add time embedding -> conv -> stride-2 conv.
    """
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t, ):
        features = self.bnorm1(self.relu(self.conv1(x)))
        # Broadcast the per-sample time vector over the two spatial axes
        features = features + self.relu(self.time_mlp(t))[..., None, None]
        features = self.bnorm2(self.relu(self.conv2(features)))
        return self.transform(features)


class SinusoidalPositionEmbeddings(nn.Module):
    """
    Map timesteps to fixed sinusoidal embeddings of width ``dim``.

    For k = 0 .. dim//2 - 1 the angular rate is w_k = 10000^(-2k / dim); the output is
    [sin(t * w), cos(t * w)] along a new last axis. For odd ``dim`` a zero column is
    appended so the output really has ``dim`` features (previously it had dim + 1,
    which does not fit the ``nn.Linear(dim, ...)`` that follows it).
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time: tensor of timesteps, e.g. shape (batch,).
        Returns:
            Float tensor of shape ``time.shape + (dim,)``.
        """
        half_dim = self.dim // 2
        exponents = torch.arange(0, 2 * half_dim, 2, device=time.device).float()
        angle_rates = torch.exp(exponents * -(math.log(10000.0) / self.dim))
        angles = time.unsqueeze(-1) * angle_rates

        # Use sine for the first half of the embeddings and cosine for the second half
        embeddings = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if self.dim % 2:
            embeddings = torch.cat([embeddings, embeddings.new_zeros(embeddings.shape[:-1] + (1,))], dim=-1)
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture (V1, built from ``Down`` / ``Up`` stages).

    ``forward(x, timestep)`` embeds the timesteps, projects the 3-channel input to 64
    channels, runs four ``Down`` stages (64 -> 1024 channels) and four ``Up`` stages
    (1024 -> 64 channels), concatenating the matching ``Down`` output onto the input of
    each ``Up`` stage, and maps back to 3 channels with a 1x1 conv.
    """
    def __init__(self):
        super().__init__()
        image_channels = out_dim = 3
        time_emb_dim = 32
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = down_channels[::-1]

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Consecutive channel pairs define each stage
        self.downs = nn.ModuleList(
            Down(c_in, c_out, time_emb_dim)
            for c_in, c_out in zip(down_channels[:-1], down_channels[1:])
        )
        self.ups = nn.ModuleList(
            Up(c_in, c_out, time_emb_dim)
            for c_in, c_out in zip(up_channels[:-1], up_channels[1:])
        )

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        t = self.time_mlp(timestep)
        x = self.conv0(x)
        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)
        # Deepest skip first: the last Down output feeds the first Up stage
        for up, skip in zip(self.ups, reversed(skips)):
            x = up(torch.cat((x, skip), dim=1), t)
        return self.output(x)

model = SimpleUnet()
num_params = sum(param.numel() for param in model.parameters())
print("Num params: ", num_params)
model

In [None]:
# SIMPLE DIFFUSION V2.1
from torch import nn
import math

a = torch.tensor([2])
# Skip-connection scale 1/sqrt(2), used by SimpleUnet V2's forward. A plain Python float
# broadcasts against tensors on any device. `torch.rsqrt(a).cuda()` crashed on CPU-only
# runtimes and mismatched devices whenever the model ran on the CPU.
r = 1.0 / math.sqrt(2.0)

class Block(nn.Module):
    """
    V2 U-Net stage: conv -> BN -> LeakyReLU -> add time embedding -> conv -> BN -> LeakyReLU -> resample.

    With ``up=True`` the first conv expects ``2 * in_ch`` channels (features concatenated
    with the skip connection) and resampling is a stride-2 transposed conv; otherwise the
    first conv takes ``in_ch`` channels and resampling is a stride-2 conv.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        conv1_in = 2 * in_ch if up else in_ch
        resample = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(conv1_in, out_ch, kernel_size=3, padding=1)
        self.transform = resample(out_ch, out_ch, kernel_size=4, stride=2, padding=1)

        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x, t):
        h = self.leaky_relu(self.bnorm1(self.conv1(x)))
        # Time embedding broadcast over the two spatial axes
        h = h + self.time_mlp(t)[..., None, None]
        h = self.leaky_relu(self.bnorm2(self.conv2(h)))
        return self.transform(h)

class SinusoidalPositionEmbeddings(nn.Module):
    """
    Map timesteps to fixed sinusoidal embeddings of width ``dim``.

    For k = 0 .. dim//2 - 1 the angular rate is w_k = 10000^(-2k / dim); the output is
    [sin(t * w), cos(t * w)] along a new last axis. For odd ``dim`` a zero column is
    appended so the output really has ``dim`` features (previously it had dim + 1,
    which does not fit the ``nn.Linear(dim, ...)`` that follows it).
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time: tensor of timesteps, e.g. shape (batch,).
        Returns:
            Float tensor of shape ``time.shape + (dim,)``.
        """
        half_dim = self.dim // 2
        exponents = torch.arange(0, 2 * half_dim, 2, device=time.device).float()
        angle_rates = torch.exp(exponents * -(math.log(10000.0) / self.dim))
        angles = time.unsqueeze(-1) * angle_rates

        # Use sine for the first half of the embeddings and cosine for the second half
        embeddings = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if self.dim % 2:
            embeddings = torch.cat([embeddings, embeddings.new_zeros(embeddings.shape[:-1] + (1,))], dim=-1)
        return embeddings

class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture (V2, built from ``Block`` stages).

    Same layout as V1: 3 -> 64 channel projection, four down Blocks (64 -> 1024),
    four up Blocks (1024 -> 64) and a 1x1 conv back to 3 channels. Each skip
    connection is multiplied by the global ``r`` before being concatenated onto
    the up-path features.
    """

    def __init__(self):
        super().__init__()
        image_channels = out_dim = 3
        time_emb_dim = 32
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = down_channels[::-1]

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Consecutive channel pairs define each stage
        self.downs = nn.ModuleList(
            Block(c_in, c_out, time_emb_dim)
            for c_in, c_out in zip(down_channels[:-1], down_channels[1:])
        )
        self.ups = nn.ModuleList(
            Block(c_in, c_out, time_emb_dim, up=True)
            for c_in, c_out in zip(up_channels[:-1], up_channels[1:])
        )

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        t = self.time_mlp(timestep)
        x = self.conv0(x)
        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)
        # Deepest skip first; each skip is scaled by r before concatenation
        for up, skip in zip(self.ups, reversed(skips)):
            x = up(torch.cat((x, skip * r), dim=1), t)
        return self.output(x)

def weights_init(m):
    """
    Module initializer for ``model.apply``: Xavier-uniform weights for every Conv2d /
    ConvTranspose2d, scaled by the LeakyReLU(0.2) gain. Biases and all other module
    types are left untouched.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if not isinstance(m, conv_types):
        return
    leaky_gain = nn.init.calculate_gain('leaky_relu', 0.2)
    nn.init.xavier_uniform_(m.weight, gain=leaky_gain)

# nn.Module.apply returns the module itself, so construction and init chain
model = SimpleUnet().apply(weights_init)
print("Num params: ", sum(param.numel() for param in model.parameters()))
model

In [None]:
# SIMPLE DIFFUSION V3, NOT WORKING
from torch import nn
import math
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_model)) V.

    The matrix products act on the last two axes of the inputs; any leading axes
    are treated as batch axes.
    """
    def __init__(self, d_model):
        super().__init__()
        self.scale_factor = math.sqrt(d_model)

    def forward(self, Q, K, V):
        weights = (Q @ K.transpose(-2, -1) / self.scale_factor).softmax(dim=-1)
        return weights @ V

class Block(nn.Module):
    """
    V3 U-Net stage: conv -> ReLU -> add time embedding -> attention residual -> conv -> ReLU -> resample.

    With ``up=True`` the first conv expects ``2 * in_ch`` channels (skip concatenation)
    and resampling is a stride-2 transposed conv; otherwise a stride-2 conv.
    BatchNorm was removed from this version for speed. ``attn_dim`` is accepted but unused.

    NOTE(review): SelfAttention is applied directly to the (B, C, H, W) feature map, so the
    matmuls treat each channel's H x W map as a matrix (rows attend over rows, contracting
    over W), and the scale uses sqrt(out_ch) rather than sqrt(W). This is not self-attention
    over pixels — possibly related to this cell being marked NOT WORKING; confirm.
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False, attn_dim=32):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        if up:
            conv1_in, resample = 2 * in_ch, nn.ConvTranspose2d
        else:
            conv1_in, resample = in_ch, nn.Conv2d
        self.conv1 = nn.Conv2d(conv1_in, out_ch, kernel_size=3, padding=1)
        self.transform = resample(out_ch, out_ch, kernel_size=4, stride=2, padding=1)

        # BatchNorm layers (bnorm1 / bnorm2) intentionally omitted for speed
        self.attention = SelfAttention(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        h = self.relu(self.conv1(x))
        # Time embedding broadcast over the two spatial axes
        h = h + self.time_mlp(t)[..., None, None]
        # Residual attention (see the class NOTE on what it attends over)
        h = h + self.attention(h, h, h)
        h = self.relu(self.conv2(h))
        return self.transform(h)

class SinusoidalPositionEmbeddings(nn.Module):
    """
    Map timesteps to fixed sinusoidal embeddings of width ``dim``.

    For k = 0 .. dim//2 - 1 the angular rate is w_k = 10000^(-2k / dim); the output is
    [sin(t * w), cos(t * w)] along a new last axis. For odd ``dim`` a zero column is
    appended so the output really has ``dim`` features (previously it had dim + 1,
    which does not fit the ``nn.Linear(dim, ...)`` that follows it).
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time: tensor of timesteps, e.g. shape (batch,).
        Returns:
            Float tensor of shape ``time.shape + (dim,)``.
        """
        half_dim = self.dim // 2
        exponents = torch.arange(0, 2 * half_dim, 2, device=time.device).float()
        angle_rates = torch.exp(exponents * -(math.log(10000.0) / self.dim))
        angles = time.unsqueeze(-1) * angle_rates

        # Use sine for the first half of the embeddings and cosine for the second half
        embeddings = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if self.dim % 2:
            embeddings = torch.cat([embeddings, embeddings.new_zeros(embeddings.shape[:-1] + (1,))], dim=-1)
        return embeddings

class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture (V3, attention ``Block`` stages).

    Same layout as V1/V2: 3 -> 64 channel projection, four down Blocks (64 -> 1024),
    four up Blocks (1024 -> 64) and a 1x1 conv back to 3 channels; skip connections
    are concatenated unscaled. Every conv / transposed conv is re-initialised with
    Xavier-uniform weights at construction.
    """
    def weights_init(self, m):
        """Xavier-uniform init for Conv2d / ConvTranspose2d weights; everything else keeps its default."""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.xavier_uniform_(m.weight)

    def __init__(self):
        super().__init__()
        image_channels = out_dim = 3
        time_emb_dim = 32
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = down_channels[::-1]

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Consecutive channel pairs define each stage
        self.downs = nn.ModuleList(
            Block(c_in, c_out, time_emb_dim, attn_dim=32)
            for c_in, c_out in zip(down_channels[:-1], down_channels[1:])
        )
        self.ups = nn.ModuleList(
            Block(c_in, c_out, time_emb_dim, up=True, attn_dim=32)
            for c_in, c_out in zip(up_channels[:-1], up_channels[1:])
        )

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

        # Applies to every submodule, including conv0 and output
        self.apply(self.weights_init)

    def forward(self, x, timestep):
        t = self.time_mlp(timestep)
        x = self.conv0(x)
        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)
        # Deepest skip first: the last down output feeds the first up stage
        for up, skip in zip(self.ups, reversed(skips)):
            x = up(torch.cat((x, skip), dim=1), t)
        return self.output(x)

model = SimpleUnet()
num_params = sum(param.numel() for param in model.parameters())
print("Num params: ", num_params)
model


### Loss
Reused

In [None]:
def get_loss(model, x_0, t):
    """
    L1 loss between the true noise and the noise predicted by ``model``.

    ``x_0`` is noised to timesteps ``t`` with ``forward_diffusion_sample`` and the
    model is asked to recover the added noise. The noisy batch is placed on
    ``t.device``: ``t`` is fed to the model too, so it must already be on the
    model's device. This avoids depending on a global ``device`` that is only
    defined in a later cell.
    """
    x_noisy, noise = forward_diffusion_sample(x_0, t, t.device)
    noise_pred = model(x_noisy, t)
    return F.l1_loss(noise_pred, noise)

In [None]:
@torch.no_grad()
def sample_timestep(x, t):
    """
    One reverse-diffusion step with the global ``model``: estimate x_{t-1} from x_t.

    Computes the posterior mean from the model's noise prediction, then adds
    ``sqrt(posterior_variance_t) * N(0, I)`` noise for every sample whose timestep is
    not 0. ``t`` holds one timestep per batch element. Previously ``if t == 0`` only
    worked for a batch of one. When every timestep is 0, no noise is drawn and the
    mean is returned.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    nonzero = t != 0
    if not bool(nonzero.any()):
        # The t's are offset from the t's in the paper: index 0 is the final step
        return model_mean

    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    # Per-sample mask so a batch may mix final (t == 0) and intermediate steps
    noise_mask = nonzero.to(model_mean.dtype).reshape(-1, *((1,) * (x.dim() - 1)))
    noise = torch.randn_like(x)
    return model_mean + noise_mask * torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample_plot_image(num_images=10):
    """
    Generate one image from pure noise with the reverse process and plot snapshots.

    Runs ``sample_timestep`` for t = T-1 .. 0 on ``device``, clamping to [-1, 1] after
    every step. A snapshot is drawn every ``T // num_images`` steps in a single row,
    with t = 0 (the final image) in the leftmost panel.

    Args:
        num_images: requested number of snapshots. The actual count is
            ``(T - 1) // stepsize + 1``, which equals ``num_images`` when T is a
            multiple of it.
    """
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15,15))
    plt.axis('off')
    # At least 1, otherwise T < num_images would make `i % stepsize` divide by zero
    stepsize = max(1, T // num_images)
    # Size the grid from the panels actually drawn so the subplot index never exceeds it
    num_panels = (T - 1) // stepsize + 1

    for i in reversed(range(T)):
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        # Edit: This is to maintain the natural range of the distribution
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            plt.subplot(1, num_panels, i // stepsize + 1)
            show_tensor_image(img.detach().cpu())
    plt.show()

### Patch Diffusion

In [None]:
# For sample
def sample_patchify(image, device="cuda"):
  """
  Append two pixel-coordinate channels to a single image.

  Args:
    image: (C, H, W) or (1, C, H, W) tensor.
    device: device the result is built on (the image is moved there too).

  Returns:
    (1, C + 2, H, W) tensor. Channel C holds each pixel's x (column) index and
    channel C + 1 its y (row) index. Works for non-square images; the coordinate
    grids used to be (H, H) and (W, W), which only lined up when H == W.
  """
  if image.dim() == 4:
    # Drop only the batch axis; torch.squeeze() would also drop a size-1 channel/height/width axis
    image = image.squeeze(0)
  image = image.to(device)
  h, w = image.size(1), image.size(2)
  x_pos = torch.arange(w, dtype=torch.long, device=device).expand(h, w)
  y_pos = torch.arange(h, dtype=torch.long, device=device).view(-1, 1).expand(h, w)
  image_pos = torch.stack((x_pos, y_pos))

  # NOTE(review): unlike patchify(), these coordinates are raw pixel indices, not
  # normalized to [-1, 1] — confirm this matches what the model is trained on.
  return torch.concat((image, image_pos), ).unsqueeze(0)


# Patch Diffusion
def patchify(images, patch_size=32, device="cuda"):
  """
  Crop one random ``patch_size`` x ``patch_size`` patch from every image in a batch.

  Args:
    images: (batch, channels, height, width) tensor.
    patch_size: side length of the square patch.
    device: device the patches, indices and positions are created on.

  Returns:
    patches: (batch, channels, patch_size, patch_size) tensor on ``device``.
    images_pos: (batch, 2, patch_size, patch_size) float tensor. Channel 0 holds the
      x (column) and channel 1 the y (row) coordinate of each patch pixel in the
      full image. x is scaled by the image width and y by the image height, so the
      full image spans [-1, 1]. Previously both were scaled by the height, which
      left x outside [-1, 1] for non-square images.

  Raises:
    ValueError: if the patch is larger than the image.
  """
  batch_size, _, h, w = images.shape
  th = tw = patch_size
  if th > h or tw > w:
    raise ValueError(f"patch_size {patch_size} exceeds image size {h}x{w}")

  # Randomly sample patch upper-left corner pixel
  if w == tw and h == th:
    i = torch.zeros((batch_size,), device=device).long()
    j = torch.zeros((batch_size,), device=device).long()
  else:
    i = torch.randint(0, h - th + 1, (batch_size,), device=device)
    j = torch.randint(0, w - tw + 1, (batch_size,), device=device)

  # Absolute row / column indices covered by each patch: (B, th) and (B, tw)
  rows = i[:, None] + torch.arange(th, dtype=torch.long, device=device)
  columns = j[:, None] + torch.arange(tw, dtype=torch.long, device=device)

  # Advanced indexing with index tensors of shape (B,1,1), (B,th,1) and (B,1,tw)
  # broadcast to (B,th,tw). With channels moved first, the leading slice yields (C,B,th,tw).
  batch_idx = torch.arange(batch_size, device=device)[:, None, None]
  channels_first = images.to(device).permute(1, 0, 2, 3)
  patches = channels_first[:, batch_idx, rows[:, :, None], columns[:, None, :]].permute(1, 0, 2, 3)

  # Per-pixel absolute coordinates, shape (B, 1, th, tw)
  x_pos = columns[:, None, None, :].expand(batch_size, 1, th, tw)
  y_pos = rows[:, None, :, None].expand(batch_size, 1, th, tw)

  # Normalize to -1 to 1 (max(..., 1) avoids 0/0 for a 1-pixel axis)
  x_pos = (x_pos / max(w - 1, 1) - 0.5) * 2.
  y_pos = (y_pos / max(h - 1, 1) - 0.5) * 2.

  images_pos = torch.cat((x_pos, y_pos), dim=1)
  return patches, images_pos


def patch_loss(model, batch, t, patch_size, resolution, device="cuda"):
  """
  Crop a random patch from every image in ``batch`` and run ``model`` on the patches.

  ``device`` is forwarded to ``patchify``; it was previously ignored, so patches
  always went to patchify's default "cuda".

  NOTE(review): despite its name this returns the model's output for the patches,
  not a scalar loss. ``resolution`` is unused, and the positional channels from
  patchify are not passed to the model yet (intended call:
  ``model(images, t, images_pos)``).
  """
  images, images_pos = patchify(batch, patch_size, device=device)
  # Call the module (not .forward) so registered hooks run
  return model(images, t)

### Training
We wrote this code

In [None]:
from torch.optim import Adam
from tqdm import tqdm
from statistics import mean
from torch.optim.lr_scheduler import ExponentialLR
import pickle
import time
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Adam(model.parameters(), lr=0.00125)
epochs = 100 # Try more!

# Lowest per-epoch mean loss so far; None until the first epoch finishes
bestLoss = None

# Per-epoch mean training loss, pickled after every epoch for the plotting cell
lossOverTime = []

# LOAD_FLOWER_PATH = "/content/drive/MyDrive/EECS 442 final project/Resources/simple_diffusion_models/flowers2.pth"
SAVE_FLOWER_PATH = "/content/drive/MyDrive/EECS 442 final project/Resources/simple_diffusion_models/flowers3.pth"
SAVE_FLOWER_PATH_COLLAB_FOLDER = "/content/drive/MyDrive/Colab Notebooks/EECS 442 final project/Resources/simple_diffusion_models/flowers3.pth"
LOSSES_PATH = '/content/drive/MyDrive/Colab Notebooks/EECS 442 final project/Resources/simple_diffusion_models/flower_losses.pkl'

# model.load_state_dict(torch.load(FLOWER_PATH))
start_time = time.time()

for epoch in range(epochs):
    model.train()
    losses = []
    print("Epoch", epoch)
    for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        optimizer.zero_grad()

        images = batch[0]
        # One random timestep per image, sized from the actual batch
        t = torch.randint(0, T, (images.shape[0],), device=device).long()
        loss = get_loss(model, images, t)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

        if epoch % 5 == 0 and step == 0:
            print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
            # Sample in eval mode so BatchNorm uses (and does not update) its running statistics
            model.eval()
            sample_plot_image()
            model.train()

    meanLoss = mean(losses)
    lossOverTime.append(meanLoss)
    print("Mean Loss:", meanLoss)

    # Persist after every epoch so the saved history always includes the latest epoch
    with open(LOSSES_PATH, 'wb') as file:
        pickle.dump(lossOverTime, file)

    # Checkpoint on every improvement, including the first epoch, so the visualization
    # cell always has a checkpoint to load; `is None` also handles a best loss of 0.0
    if bestLoss is None or meanLoss < bestLoss:
        print("New Best Loss:", meanLoss)
        bestLoss = meanLoss
        torch.save(model.state_dict(), SAVE_FLOWER_PATH_COLLAB_FOLDER)


end_time = time.time()
training_time = end_time - start_time
print(f"Total training time: {training_time} seconds")

In [None]:
# Visualize results from the best checkpoint saved during training.
# map_location lets a checkpoint saved from a GPU run load on the current device.
model.load_state_dict(torch.load(SAVE_FLOWER_PATH_COLLAB_FOLDER, map_location=device))
# Sample with BatchNorm in inference mode (running statistics, no updates)
model.eval()
sample_plot_image()

In [None]:
# Plot losses
import pickle
import matplotlib.pyplot as plt

# Load the pickle file
# Per-epoch mean losses written by the training cell.
# pickle.load can execute arbitrary code, so only point this at files this notebook produced.
pickle_file_path = '/content/drive/MyDrive/Colab Notebooks/EECS 442 final project/Resources/simple_diffusion_models/flower_losses.pkl'
with open(pickle_file_path, 'rb') as file:
    loss_over_time = pickle.load(file)

# Draw on the current axes (the same target the pyplot shortcut functions use)
ax = plt.gca()
ax.plot(loss_over_time, label='Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Over Time')
ax.legend()
plt.show()