In [52]:
import torch
import numpy as np
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision
import torch.nn.functional as F
import torch.nn as nn

def get_data_loader(batch_size, position="samePos"):
    
    transform = transforms.Compose(
        [transforms.Resize((224,224)), transforms.ToTensor()])
    
    if position == "diffPos":
        data = torchvision.datasets.ImageFolder('watermarked_image/diffPos/1', \
                                                transform = transform)
    else:
        data = torchvision.datasets.ImageFolder('watermarked_image/samePos/1', \
                                                transform = transform)
    trainLen = int(0.6*len(data))
    valLen = int(0.2*len(data))
    testLen = int(0.2*len(data))
    
    train_data, val_data, test_data = torch.utils.data.random_split\
    (data, [trainLen, valLen, testLen])
    
    train_data_loader = torch.utils.data.DataLoader(train_data, \
                                                    batch_size=batch_size,\
                                                    shuffle = True, \
                                                    num_workers=1)
    val_data_loader = torch.utils.data.DataLoader(val_data, \
                                                    batch_size=batch_size,\
                                                    shuffle = True, \
                                                    num_workers=1)
    test_data_loader = torch.utils.data.DataLoader(test_data, \
                                                    batch_size=batch_size,\
                                                    shuffle = True, \
                                                    num_workers=1)
    
    return train_data_loader, val_data_loader, test_data_loader

# NOTE: THIS ONLY USES ONE FOLDER!
train_data_loader, val_data_loader, test_data_loader = get_data_loader(1)
print(len(train_data_loader), len(val_data_loader), len(test_data_loader))

825 275 275


In [59]:
class oriResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
#         self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
#         x = x.view(x.size(0), -1)
#         x = self.fc(x)

        return x

In [64]:
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        self.pretrained = models.resnet50(pretrained=True)
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(1000, 500, 7, dilation=2),
            nn.ReLU(),
            nn.ConvTranspose2d(500, 3, 5, dilation=2),
            )
    def forward(self, x):
        x = self.pretrained(x)
        print(x.fc.in_features)
        x = self.upsample(x)
        return x


In [53]:
def train(model, batch_size, num_epochs=5, learning_rate=1e-4, plot=False):
    
    train_loader, val_loader, test_loader = get_data_loader(batch_size)
    
    torch.manual_seed(42)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train_accuracies = []
    val_accuracies = []
    train_loss = []
    val_loss = []
    
    for epoch in range(num_epochs):
        for data in train_loader:
            img, label = data
            
            print(img.shape)
            recon = model(img)
            print(recon.shape)
            
            loss = criterion(recon, img)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
        for valdata in valid_loader:
            img, label = data
            recon = model(img)
            
            valloss = criterion(recon, valdata)
            valloss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
        train_loss.append(loss.item())
        val_loss.append(valloss.item())
        trainacc = get_accuracy(model, train_loader)
        train_accuracies.append(trainacc)
        valacc = get_accuracy(model, valid_loader)
        val_accuracies.append(valacc)
            
        print("Epoch %d; Loss %f; Train Acc %f; Val Acc %f" % (
              epoch+1, loss, trainacc[-1], validacc[-1]))
        
    if plot == True:
        plt.title("Train vs Validation Loss")
        plt.plot(range(1,num_epochs+1), train_loss, label="Train")
        plt.plot(range(1,num_epochs+1), val_loss, label="Validation")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend(loc='best')
        plt.show()

        plt.title("Train vs Validation Accuracies")
        plt.plot(range(1,num_epochs+1), train_accuracies, label="Train")
        plt.plot(range(1,num_epochs+1), val_accuracies, label="Validation")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.legend(loc='best')
        plt.show()

In [68]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

In [72]:
resModel = ResNet()
pretrained = models.resnet50(pretrained=True)
print(pretrained)
pretrained.fc = Identity()
train_loader, val_loader, test_loader = get_data_loader(1)
for data in train_loader:
    img, label = data
    output = pretrained(img)
    print(output.shape)
# train(resModel, batch_size=32, learning_rate =0.01, num_epochs=30)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=F

torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])
torch.Size([1, 2048])


KeyboardInterrupt: 