<a href="https://colab.research.google.com/github/shcDE/Vision_Paper_Review_Code/blob/main/%EC%9D%B4%EC%9A%B0%EC%A7%84%EA%B5%90%EC%88%98_%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84_%EC%BD%94%EB%93%9C.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the generator and discriminator architectures
# Placeholder for actual neural network architectures
class Generator(nn.Module):
    """Fully connected generator that maps latent noise to flattened samples.

    The network is a four-layer MLP (256 -> 512 -> 1024 -> ``output_dim``)
    with ReLU activations and a final ``Tanh``, so every output value lies
    in ``[-1, 1]``.

    Args:
        latent_dim: Size of the noise vector fed to the network.
            Defaults to 100.
        output_dim: Size of each generated sample. Defaults to 784, the
            length of a flattened 28x28 (MNIST-sized) image.
    """

    def __init__(self, latent_dim: int = 100, output_dim: int = 784) -> None:
        super().__init__()
        # Kept as plain attributes so callers can sample noise of the right
        # width without repeating the constant.
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        # Layer order (and therefore the ``model.<idx>`` parameter names) is
        # part of the checkpoint format; do not reorder.
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, output_dim),
            nn.Tanh()
        )

    def forward(self, z):
        """Return generated samples for noise ``z`` of shape ``(..., latent_dim)``."""
        return self.model(z)

class Discriminator(nn.Module):
    """Fully connected discriminator that scores samples as real or generated.

    The network is a three-layer MLP (``input_dim`` -> 512 -> 256 -> 1) with
    LeakyReLU(0.2) activations and a final ``Sigmoid``, so the output is a
    probability-like value in ``[0, 1]`` suitable for ``nn.BCELoss``.

    Args:
        input_dim: Size of each flattened input sample. Defaults to 784.
    """

    def __init__(self, input_dim: int = 784) -> None:
        super().__init__()
        self.input_dim = input_dim
        # Layer order (and therefore the ``model.<idx>`` parameter names) is
        # part of the checkpoint format; do not reorder.
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return one score per sample for ``x`` of shape ``(..., input_dim)``."""
        return self.model(x)

# Training hyperparameters
epochs, batch_size, latent_dim = 200, 64, 100

# Build the two networks (generator first, then discriminator)
generator = Generator()
discriminator = Discriminator()

# Binary cross-entropy is used for both the generator and discriminator losses
adversarial_loss = nn.BCELoss()

# One Adam optimizer per network, both with identical settings
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Stand-in data loaders: each wraps 1000 rows of 784-dimensional Gaussian
# noise. Real loaders for the incomplete and complete data would replace
# these. The incomplete loader is created first, then the complete one.
incomplete_data_loader, complete_data_loader = (
    torch.utils.data.DataLoader(torch.randn(1000, 784), batch_size=batch_size)
    for _ in range(2)
)

# Adversarial training loop.
# The two loaders are iterated in lockstep with zip(), so each epoch ends as
# soon as the shorter loader is exhausted. Labels and noise are re-sized every
# step so a smaller final batch is handled correctly.
for epoch in range(epochs):
    for i, (incomplete_data, complete_data) in enumerate(zip(incomplete_data_loader, complete_data_loader)):
        # Configure input (cast to a CPU float32 tensor)
        real_data = complete_data.type(torch.FloatTensor)

        # Size labels and noise from the real batch: `valid` is compared
        # against discriminator(real_data) below, and BCELoss requires the
        # target to have the same shape as its input. Sizing from
        # `incomplete_data` would fail if the two loaders ever produced
        # batches of different lengths.
        # NOTE: `incomplete_data` is not consumed by this loop yet.
        current_batch_size = real_data.size(0)
        valid = torch.ones(current_batch_size, 1)
        fake = torch.zeros(current_batch_size, 1)

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = torch.randn(current_batch_size, latent_dim)

        # Generate a batch of data
        generated_data = generator(z)

        # Loss measures generator's ability to fool the discriminator:
        # generated samples are scored against the "valid" label.
        g_loss = adversarial_loss(discriminator(generated_data), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # zero_grad() also discards the discriminator gradients that the
        # generator's backward pass accumulated above.
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated
        # samples. detach() keeps this backward pass out of the generator.
        real_loss = adversarial_loss(discriminator(real_data), valid)
        fake_loss = adversarial_loss(discriminator(generated_data.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Output training stats
        if i % 50 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(incomplete_data_loader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

# Reached only after every epoch above has finished; nothing is cut short here.
print("Training would proceed...")


[Epoch 0/200] [Batch 0/16] [D loss: 0.6960928440093994] [G loss: 0.6949040293693542]
[Epoch 1/200] [Batch 0/16] [D loss: 0.4224514961242676] [G loss: 0.8487991094589233]
[Epoch 2/200] [Batch 0/16] [D loss: 0.4301900863647461] [G loss: 0.6969741582870483]
[Epoch 3/200] [Batch 0/16] [D loss: 0.5886889100074768] [G loss: 0.5880992412567139]
[Epoch 4/200] [Batch 0/16] [D loss: 0.5525980591773987] [G loss: 1.649226188659668]
[Epoch 5/200] [Batch 0/16] [D loss: 0.5157387852668762] [G loss: 1.9520888328552246]
[Epoch 6/200] [Batch 0/16] [D loss: 0.43164700269699097] [G loss: 1.4265207052230835]
[Epoch 7/200] [Batch 0/16] [D loss: 0.3870055675506592] [G loss: 1.2351629734039307]
[Epoch 8/200] [Batch 0/16] [D loss: 0.4173913300037384] [G loss: 0.883663535118103]
[Epoch 9/200] [Batch 0/16] [D loss: 0.4602462649345398] [G loss: 0.6586522459983826]
[Epoch 10/200] [Batch 0/16] [D loss: 0.44308745861053467] [G loss: 0.7645212411880493]
[Epoch 11/200] [Batch 0/16] [D loss: 0.5217755436897278] [G loss