In [1]:
!pip install torch torch-geometric



In [2]:
# prompt: load a large directory of files

from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
from torch_geometric.nn import GAE, GCN2Conv
import torch.nn.functional as F
from torch.nn import Linear
import os
import torch

num_features = 392
print(num_features)
hidden = 128
out_channels = 32

class GCNEncoder(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels, num_layers, alpha, theta, shared_weights=True, dropout=0.0):
        super().__init__()

        self.lins = torch.nn.ModuleList()
        self.lins.append(Linear(input_channels, hidden_channels))
        self.lins.append(Linear(hidden_channels, output_channels))

        self.convs = torch.nn.ModuleList()
        for layer in range(num_layers):
            self.convs.append(
                GCN2Conv(hidden_channels, alpha, theta, layer + 1,
                         shared_weights, normalize=False))

        self.dropout = dropout

    def forward(self, x, edge_index):
        x = F.dropout(x, self.dropout, training=self.training)
        x = x_0 = self.lins[0](x).relu()

        for conv in self.convs:
            x = F.dropout(x, self.dropout, training=self.training)
            x = conv(x, x_0, edge_index)
            x = x.relu()

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.lins[1](x)

        return x

encoder = GCNEncoder(input_channels=num_features, hidden_channels=hidden, output_channels=out_channels, num_layers=6, alpha=0.1, theta=0.5, shared_weights=True, dropout=0.1)
model = GAE(encoder)
EPOCH = 1

MODEL_PATH = f"/content/drive/MyDrive/product_page_dataset/model/product_page_model_5_{EPOCH}.torch"

if os.path.exists(MODEL_PATH):
  print("loading model from", MODEL_PATH)
  model.load_state_dict(torch.load(MODEL_PATH))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model.to(device)

print(model)

392
GAE(
  (encoder): GCNEncoder(
    (lins): ModuleList(
      (0): Linear(in_features=392, out_features=128, bias=True)
      (1): Linear(in_features=128, out_features=32, bias=True)
    )
    (convs): ModuleList(
      (0): GCN2Conv(128, alpha=0.1, beta=0.4054651081081644)
      (1): GCN2Conv(128, alpha=0.1, beta=0.22314355131420976)
      (2): GCN2Conv(128, alpha=0.1, beta=0.15415067982725836)
      (3): GCN2Conv(128, alpha=0.1, beta=0.11778303565638346)
      (4): GCN2Conv(128, alpha=0.1, beta=0.09531017980432493)
      (5): GCN2Conv(128, alpha=0.1, beta=0.08004270767353636)
    )
  )
  (decoder): InnerProductDecoder()
)


In [None]:
from pickle import load
import traceback
from torch_geometric.data import Dataset, download_url, Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import subgraph
import os
import numpy as np
import time

BATCHSIZE = 4

class ProductPageDataset(Dataset):
  def __init__(self, root, transform=None, pre_transform=None):
    self.names = [page for page in os.listdir("/content/drive/MyDrive/product_page_dataset/dataset/raw") if ".x" in page]
    super().__init__(root, transform, pre_transform)


  @property
  def raw_file_names(self):
    return self.names

  @property
  def processed_file_names(self):
    return []

  def download(self):
    pass

  def process(self):
    pass

  def len(self):
    return len(self.raw_file_names)

  def get(self, idx):
    f = os.path.join("/content/drive/MyDrive/product_page_dataset/dataset/raw", self.raw_file_names[idx])
    fe = os.path.join("/content/drive/MyDrive/product_page_dataset/dataset/raw", self.raw_file_names[idx].replace(".x", ".e"))
    mask = os.path.join("/content/drive/MyDrive/product_page_dataset/dataset/raw", self.raw_file_names[idx].replace(".x", ".mask"))
    try:
      with open(f, 'rb') as file:
        x = torch.load(file)
        with open(fe, 'rb') as filee:
          e = torch.load(filee).to(torch.int64)
          with (open(mask, 'rb')) as filem:
            mask = torch.load(filem)
            print(x.shape, e.shape, mask.shape, mask.dtype)
            return Data(x=x, edge_index=e, mask=mask)
    except:
      traceback.print_exc()
      print(f)
      return
dataset = ProductPageDataset(root='/content/drive/MyDrive/product_page_dataset/dataset/')
dataloader = DataLoader(dataset, batch_size=BATCHSIZE)

print(f"data loader {len(dataloader)}")

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# def rough_train():
#     model.train()
#     for chunk, sub_data in enumerate(dataloader):
#       data = sub_data.to(device)
#       z = model.encode(data.x, data.edge_index)
#       loss = model.recon_loss(z, data.edge_index)
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       print(chunk, loss.item())
#     return float(loss)

# for i in range(10):
#   loss = rough_train()
# print("rough train", loss)



def smooth_train():
    overall_loss = 0
    optimizer.zero_grad()
    model.train()
    for chunk, sub_data in enumerate(dataloader):
      data = sub_data.to(device)
      z = model.encode(data.x, data.edge_index)

      subset_edge_index, _ = subgraph(data.mask, data.edge_index)

      loss = model.recon_loss(z, subset_edge_index)
      print(chunk, loss.item())
      overall_loss += loss.item() / len(dataloader)
      loss = loss / len(dataloader)
      loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return overall_loss

for epoch in range(EPOCH+1, 101):
    t0 = time.time()
    loss = smooth_train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Time: {time.time() - t0}')
    torch.save(model.state_dict(), f'/content/drive/MyDrive/product_page_dataset/model/product_page_model_4_{epoch}.torch')


  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
Processing...
Done!
  x = torch.load(file)


data loader 38


  e = torch.load(filee).to(torch.int64)
  mask = torch.load(filem)


torch.Size([94540, 392]) torch.Size([2, 94440]) torch.Size([94540]) torch.bool
torch.Size([105705, 392]) torch.Size([2, 105605]) torch.Size([105705]) torch.bool
torch.Size([134752, 392]) torch.Size([2, 134652]) torch.Size([134752]) torch.bool
torch.Size([69712, 392]) torch.Size([2, 69612]) torch.Size([69712]) torch.bool
0 1.3886240720748901
torch.Size([59785, 392]) torch.Size([2, 59685]) torch.Size([59785]) torch.bool
torch.Size([78358, 392]) torch.Size([2, 78258]) torch.Size([78358]) torch.bool
torch.Size([56106, 392]) torch.Size([2, 56006]) torch.Size([56106]) torch.bool
torch.Size([88961, 392]) torch.Size([2, 88861]) torch.Size([88961]) torch.bool
1 1.388566493988037
torch.Size([50591, 392]) torch.Size([2, 50491]) torch.Size([50591]) torch.bool
torch.Size([73065, 392]) torch.Size([2, 72965]) torch.Size([73065]) torch.bool
torch.Size([114014, 392]) torch.Size([2, 113914]) torch.Size([114014]) torch.bool
torch.Size([54662, 392]) torch.Size([2, 54562]) torch.Size([54662]) torch.bool
2 