In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Generator network
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.model(x)


In [5]:
# Discriminator network
class Discriminator(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [13]:
# Solver
class Solver(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim):
        super(Solver, self).__init__()
        self.generator = Generator(input_dim, output_dim, hidden_dim)
        self.discriminator = Discriminator(output_dim, hidden_dim)

    def forward(self, z):
        return self.generator(z)

In [22]:
input_dim = 128  # Size of the random noise vector
output_dim = 256  # Size of the solution for Burgers' equation
hidden_dim = 128

In [19]:
solver = Solver(input_dim, output_dim, hidden_dim)

In [23]:
criterion = nn.BCELoss()
d_optimizer = optim.Adam(solver.discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(solver.generator.parameters(), lr=0.0002)

In [24]:
nIter = 1000
batch_size = 64

for epoch in range(nIter):
    for _ in range(batch_size):
        real_data = torch.randn(batch_size, input_dim)
        fake_data = solver.generator(real_data)

        # Train the discriminator
        d_optimizer.zero_grad()
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        outputs_real = solver.discriminator(real_data)
        outputs_fake = solver.discriminator(fake_data)

        d_loss_real = criterion(outputs_real, real_labels)
        d_loss_fake = criterion(outputs_fake, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # Train the generator
        g_optimizer.zero_grad()
        fake_data = solver.generator(torch.randn(batch_size, input_dim))
        outputs = solver.discriminator(fake_data)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{nIter}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x256 and 128x128)