# Differentiating Concepts and Instances for Knowledge Graph Embedding

## Xin Lv, Lei Hou, Juanzi Li, Zhiyuan Liu


### Summary

[Paper PDF](https://arxiv.org/pdf/1811.04588v1.pdf)

[Paper Code](https://github.com/davidlvxin/TransC)

The paper proposes TransC, a knowledge graph embedding model that keeps concepts and instances distinct. TransC represents each concept in the knowledge graph as a sphere in the semantic space and each instance as a vector in the same space. The relations between concepts and instances (instanceOf) and between concepts and sub-concepts (subClassOf) are modeled by their relative positions. The model was tested on a dataset based on the YAGO knowledge base. The results show that it captures the semantic link between the instanceOf and subClassOf relations, while other knowledge embedding models do not differentiate concepts from instances.

The approach defines three different loss functions that measure relative positions in the embedding space, and then jointly learns the representations of concepts, instances, and relations on top of translation-based models. Each concept $c$ is a sphere $s(p, m)$ with centre $p$ and radius $m$. The first loss function handles the _instanceOf_ relation: $f_e(i,c) = ||i-p||_2 - m$. A triple is valid if the instance vector $i$ lies inside the sphere $s$ of the concept. The second loss function handles the _subClassOf_ relation. It uses the distance between the centres of the spheres $s_1$ and $s_2$ that denote the two concepts: $d = ||p_1 - p_2||_2$. If $s_1$ is inside $s_2$, the triple is valid. If $s_1$ is separate from $s_2$, the loss function is defined as $f_c(c_1, c_2) = ||p_1 - p_2||_2 + m_1 - m_2$; if the two spheres intersect, the loss function is the same; if $s_2$ is inside $s_1$, the loss function is defined as $f_c(c_1, c_2) = m_1 - m_2$. The last loss function handles relational triples and is defined as $f_r(h,t) = ||h + r - t||_2^2$. From this geometry it follows that if $i$ is inside sphere $s_1$ and $s_1$ is inside $s_2$, then $i$ is also inside $s_2$, so the _instanceOf_ relation between $i$ and the concept of $s_2$ holds as well.
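The three scoring functions can be sketched in plain Python to make the geometry concrete. This is a minimal illustration and not the authors' implementation: the helper `l2` and the use of lists as vectors are assumptions, and returning `0.0` when $s_1$ already lies inside $s_2$ is a choice made here to mark the triple as satisfied.

```python
import math

def l2(u, v):
    """Euclidean distance between two equal-length vectors (hypothetical helper)."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def f_e(i, p, m):
    """instanceOf score: distance from instance i to centre p, minus radius m.
    A value <= 0 means i lies inside the sphere."""
    return l2(i, p) - m

def f_c(p1, m1, p2, m2):
    """subClassOf score for (c1, subClassOf, c2)."""
    d = l2(p1, p2)
    if d <= m2 - m1:       # s1 inside s2: the triple is valid
        return 0.0         # assumption: report zero for a satisfied triple
    if d <= m1 - m2:       # s2 inside s1
        return m1 - m2
    return d + m1 - m2     # spheres are separate or intersect

def f_r(h, r, t):
    """Relational score: squared L2 norm of h + r - t."""
    return sum((a + b - c) ** 2 for a, b, c in zip(h, r, t))

# Toy values, chosen only for illustration.
inside = f_e([0.5, 0.0], [0.0, 0.0], 1.0)            # instance inside the sphere
separate = f_c([0.0, 0.0], 1.0, [3.0, 0.0], 1.0)     # two separate spheres
```

With these toy values, `inside` is negative because the instance sits within the radius, and `separate` is positive because the centres are farther apart than the radii allow.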

The goal of training TransC is to minimize the loss function and iteratively update the embeddings of concepts, instances, and relations. Since every triple in the knowledge graph is positive, a set of negative triples is constructed by corrupting positive triples, that is, by replacing the head or the tail with another instance or concept. During training, a margin-based ranking loss is defined for *instanceOf* triples: $L_e = ∑_{ξ∈S_e} ∑_{ξ′∈S′_e} [γ_e + f_e(ξ) − f_e(ξ′)]_+$; for *subClassOf* triples: $L_c = ∑_{ξ∈S_c} ∑_{ξ′∈S′_c} [γ_c + f_c(ξ) − f_c(ξ′)]_+$; and for relational triples: $L_l = ∑_{ξ∈S_l} ∑_{ξ′∈S′_l} [γ_l + f_r(ξ) − f_r(ξ′)]_+$. Here $[x]_+ = \max(0, x)$, $γ$ is the margin, $S$ is the set of positive triples and $S′$ the set of negative triples. The overall loss function is the linear combination of these three functions: $L = L_e + L_c + L_l$.
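The margin-based ranking loss can be illustrated with a few lines of Python. This is a sketch under assumptions: the function name `margin_ranking_loss` is hypothetical, and each positive score is paired with one sampled negative score, which is how the notebook code further down pairs them, instead of summing over every combination of positive and negative triples.

```python
def margin_ranking_loss(pos_scores, neg_scores, gamma):
    """Sum of [gamma + f(pos) - f(neg)]_+ over paired scores.

    pos_scores[k] and neg_scores[k] are the scores of a positive triple
    and of the negative triple sampled for it.
    """
    return sum(max(0.0, gamma + p - n) for p, n in zip(pos_scores, neg_scores))

# Toy scores, chosen only for illustration.
L_e = margin_ranking_loss([0.2, 1.5], [2.0, 1.0], gamma=1.0)  # instanceOf part
L_c = margin_ranking_loss([0.1], [0.9], gamma=0.3)            # subClassOf part
L_l = margin_ranking_loss([0.4], [0.5], gamma=1.0)            # relational part
L = L_e + L_c + L_l                                           # overall loss
```

A pair contributes nothing once the negative triple scores at least `gamma` higher than the positive one, so only pairs that violate the margin produce a gradient.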

The method is evaluated on two tasks commonly used in knowledge graph embedding: link prediction and triple classification. For testing, a dataset is composed of triples from YAGO and newly created triples. In the link prediction task, the goal is to predict the missing h or t of a relational triple (h, r, t), and the model produces a ranked list of candidates for the missing position. For every relational test triple, the head or tail is removed and replaced in turn with every instance that exists in the knowledge graph. These instances are then ranked in ascending order of the distance calculated by the loss function $f_r$.
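The ranking step for tail prediction can be sketched as follows. The function `rank_tail`, the dictionary of instance vectors and the toy embeddings are all assumptions made for illustration; they are not taken from the paper's evaluation code.

```python
def f_r(h, r, t):
    """Relational score: squared L2 norm of h + r - t (lower is better)."""
    return sum((a + b - c) ** 2 for a, b, c in zip(h, r, t))

def rank_tail(h, r, true_tail, instance_vecs):
    """Replace the tail with every known instance and rank by ascending score.

    Returns the 1-based rank of the true tail and the full ranked list.
    """
    scores = {name: f_r(h, r, vec) for name, vec in instance_vecs.items()}
    ranked = sorted(scores, key=scores.get)
    return ranked.index(true_tail) + 1, ranked

# Hypothetical two-dimensional embeddings.
instances = {"a": [1.0, 1.0], "b": [0.0, 0.0], "c": [2.0, 1.0]}
rank, ranked = rank_tail([1.0, 0.0], [0.0, 1.0], "a", instances)
```

Head prediction works the same way with the candidate placed in the head position. Metrics such as mean reciprocal rank or Hits@N can then be computed from the ranks collected over all test triples.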

Triple classification is a binary classification task whose goal is to predict whether a triple is valid or not. For this task, a dataset with an equal number of positive and negative triples is created. A threshold $δ_r$ is defined for every relation. If the score that the loss function assigns to a triple is smaller than $δ_r$, the triple is classified as positive; otherwise it is classified as negative.
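The thresholding rule can be sketched in Python. The summary above does not say how $δ_r$ is chosen, so the selection step here, picking the candidate with the highest accuracy on a labelled set of scores, is an assumption, and both function names are hypothetical.

```python
def classify(score, delta):
    """A triple is predicted positive when its score is below the threshold."""
    return score < delta

def best_threshold(scores, labels):
    """Return the candidate threshold with the highest accuracy on labelled scores.

    Assumption: candidates are the observed scores plus one value above the maximum.
    """
    candidates = sorted(set(scores)) + [max(scores) + 1.0]

    def accuracy(delta):
        return sum(classify(s, delta) == y for s, y in zip(scores, labels)) / len(scores)

    return max(candidates, key=accuracy)

# Toy scores for one relation: two positive and two negative triples.
scores = [0.2, 0.5, 1.4, 2.0]
labels = [True, True, False, False]
delta_r = best_threshold(scores, labels)
prediction = classify(0.9, delta_r)
```

Keeping one threshold per relation lets relations whose scores live on different scales be classified independently.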

The authors of the paper conclude that the proposed model TransC for knowledge embedding can handle instances, concepts, and relations in the same space to deal with the transitivity of isA relations. Results from experiments show that TransC outperforms previous translation-based models in most cases.

### Code Test

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
#import argparse
from collections import Counter
import pickle as pkl
import os
import time

In [2]:
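def norm(x, pnorm=0):
    # pnorm == 1: L1 norm over the last dimension.
    # Any other value: squared L2 norm (no square root is taken).
    if pnorm == 1:
        return torch.sum(torch.abs(x), -1)
    else:
        return torch.sum(x**2,-1)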

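def normalize_emb(x):
    # Rescale vectors whose L2 norm exceeds 1 back to unit length;
    # vectors with norm <= 1 are divided by 1 and stay unchanged.
    veclen = torch.clamp_min_(torch.norm(x, 2, -1,keepdim=True), 1.0)
    ret = x/veclen
    return ret.detach()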

def normalize_radius(x):
    return torch.clamp(x,min=-1.0,max=1.0)

In [3]:
class Dataset(object):
    def __init__(self, args):
        self.dataset_name = args["dataset"]
        self.args = args
        self.entity_num, self.entity2id = self.read_file(self.dataset_name, "instance2id")
        self.relation_num, self.relation2id = self.read_file(self.dataset_name, "relation2id")
        self.concept_num, self.concept2id = self.read_file(self.dataset_name, "concept2id")
        self.triple_num, self.triples = self.read_triples(self.dataset_name, "triple2id")

        self.fb_h, self.fb_t, self.fb_r = [], [], []
        self.relation_vec,self.entity_vec,self.concept_vec = [],[],[]
        self.relation_tmp, self.entity_tmp, self.concept_tmp = [], [], []
        self.concept_r, self.concept_r_tmp = [], []
        self.ok = {}
        self.subClassOf_ok = {}
        self.instanceOf_ok = {}
        self.subClassOf = []
        self.instanceOf = []
        self.instance_concept = [[] for i in range(self.entity_num)]
        self.concept_instance = [[] for i in range(self.concept_num)]
        self.sub_up_concept = [[] for i in range(self.concept_num)]
        self.up_sub_concept = [[] for i in range(self.concept_num)]


    def read_file(self, dataset,filename,split = 'Train'):
        with open("data/" + dataset + "/" + split+"/"+filename + ".txt") as file:
            L = file.readlines()
            num = int(L[0].strip())
            contents = [[x for x in line.strip().split()] for line in L[1:]]
        return num, contents

    def read_triples(self, dataset,filename,split = 'Train'):
        with open("data/" + dataset + "/" + split+"/"+filename + ".txt") as file:
            L = file.readlines()
            num = int(L[0].strip())
            contents = [[int(x) for x in line.strip().split()] for line in L[1:]]
        return num, contents

    def read_biples(self, dataset, filename,split = 'Train'):
        with open("data/" + dataset + "/" + split+"/"+filename + ".txt") as file:
            L = file.readlines()
            contents = [[int(x) for x in line.strip().split()] for line in L[1:]]
        return contents

    def addHrt(self, x, y, z):  # x: head ,y: tail, z:relation
        self.fb_h.append(x)
        self.fb_r.append(z)
        self.fb_t.append(y)
        if (x, z) not in self.ok:
            self.ok[(x, z)] = {y: 1}
        else:
            self.ok[(x, z)][y] = 1

    def addSubClassOf(self, sub, parent):
        self.subClassOf.append([sub, parent])
        self.subClassOf_ok[(sub, parent)] = 1

    def addInstanceOf(self, instance, concept):
        self.instanceOf.append([instance, concept])
        self.instanceOf_ok[(instance, concept)] = 1

    def setup(self):
        self.left_entity = [Counter() for i in range(self.relation_num)]
        self.right_entity = [Counter() for i in range(self.relation_num)]

        for h, t, r in self.triples:
            self.addHrt(h, t, r)
            if self.args["bern"]:
                self.left_entity[r][h] += 1
                self.right_entity[r][t] += 1

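        # Average number of triples per distinct head / tail for each relation.
        # The counters stay empty when "bern" is off, so guard against division by zero.
        self.left_num = [float(sum(c.values())) / float(len(c)) if c else 0.0 for c in self.left_entity]
        self.right_num = [float(sum(c.values())) / float(len(c)) if c else 0.0 for c in self.right_entity]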

        self.instanceOf_contents = self.read_biples(self.args["dataset"], "instanceOf2id")
        self.subClassOf_contents = self.read_biples(self.args["dataset"], "subClassOf2id")

        for a, b in self.instanceOf_contents:
            self.addInstanceOf(a,b)
            self.instance_concept[a].append(b)
            self.concept_instance[b].append(a)

        for a, b in self.subClassOf_contents:
            self.addSubClassOf(a,b)
            self.sub_up_concept[a].append(b)
            self.up_sub_concept[b].append(a)


        self.instance_brother = [[ins for concept in concepts
                                for ins in self.concept_instance[concept]
                                if ins != instance_out]
                                for instance_out, concepts
                                in enumerate(self.instance_concept)]

        self.concept_brother = [[sub for up in ups
                                for sub in self.up_sub_concept[up]
                                if sub != sub_out]
                                for sub_out, ups
                                in enumerate(self.sub_up_concept)]

        self.trainSize = len(self.fb_h) + len(self.instanceOf) + len(self.subClassOf)

        print("train size {} {} {} {}".format(self.trainSize, len(self.fb_h),len(self.instanceOf),len(self.subClassOf)))

    def save(self,):
        with open("data/" + self.dataset_name + "/" + self.args["split"] + "/processed.pkl",'wb') as file:
            pkl.dump(self, file)

In [4]:
def load_processed(dataset_name,split):
    with open("data/" + dataset_name + "/" + split + "/processed.pkl",'rb') as file:
        res = pkl.load(file)
    return res

In [5]:
class Train(nn.Module):
    def __init__(self,args,dataset):
        super(Train, self).__init__()
        self.args = args
        self.D = dataset
        self.entity_vec = nn.Embedding(self.D.entity_num,args["emb_dim"])
        self.concept_vec = nn.Embedding(self.D.concept_num,args["emb_dim"]+1)
        self.relation_vec = nn.Embedding(self.D.relation_num,args["emb_dim"])
        self.optimizer = torch.optim.SGD(self.parameters(),lr=args["lr"])

        ## initialize
        nn.init.normal_(self.entity_vec.weight.data, 0.0, 1.0 / args["emb_dim"])
        nn.init.normal_(self.relation_vec.weight.data, 0.0, 1.0 / args["emb_dim"])
        nn.init.normal_(self.concept_vec.weight.data[:, :-1], 0.0, 1.0 / args["emb_dim"])
        nn.init.uniform_(self.concept_vec.weight.data[:, -1], 0.0, 1.0)

        # self.training_instance_file = open("data/cpp_training_instance.txt", 'r')
        # with open("data/cpp_training_instance.txt", 'r') as file:
        #     lines = file.readlines()
        #     lines = [line.strip().split("\t") for line in lines]
        #     self.training_instance = [[int(x) for x in line] for line in lines ]
        #     print("using saved instances")

    def doTrain(self):
        nbatches = self.args["nbatches"]
        nepoch = self.args["nepoch"]
        batchSize = int(self.D.trainSize / nbatches)
        allreadyindex = 0

        dis_a_L, dis_b_L = [], []
        dis_count = 0
        for epoch in range(nepoch):
            res = 0
            for batch in range(nbatches):
                losses = []
                stime = time.time()
                pairs = [[], [], []]

                #normalize
                self.entity_vec.weight.data = normalize_emb(self.entity_vec.weight.data)
                self.relation_vec.weight.data = normalize_emb(self.relation_vec.weight.data)
                self.concept_vec.weight.data[:, :-1] = normalize_emb(self.concept_vec.weight.data[:, :-1])
                self.concept_vec.weight.data[:, -1] = normalize_radius(self.concept_vec.weight.data[:, -1])

                self.optimizer.zero_grad()
                for k in range(batchSize):
                    i = random.randint(0, self.D.trainSize - 1)
                    if i < len(self.D.fb_r):
                        cut = 1 - epoch * self.args["hrt_cut"] / nepoch
                        pairs[0].append(self.trainHLR(i, cut))
                    elif i < len(self.D.fb_r) + len(self.D.instanceOf):
                        cut = 1 - epoch * self.args["ins_cut"] / nepoch
                        pairs[1].append(self.trainInstanceOf(i, cut))
                    else:
                        cut = 1 - epoch * self.args["sub_cut"] / nepoch
                        pairs[2].append(self.trainSubClassOf(i, cut))

                # for k in range(batchSize):
                #     line = self.training_instance_file.readline()
                #     line = line.strip().split("\t")
                #     instance = [int(x) for x in line]
                #     # print(instance)
                #     if instance[0] == -1:
                #         pairs[0].append(instance[1:])
                #     if instance[0] == -2:
                #         pairs[1].append(instance[1:])
                #     if instance[0] == -3:
                #         pairs[2].append(instance[1:])
                # allreadyindex += batchSize

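                tensor_pairs= []
                # Follow the device of the model parameters instead of hard-coding CUDA.
                device = self.entity_vec.weight.device
                for i in range(3):
                    tensor_pairs.append(torch.stack([torch.tensor(x) for x in list(zip(*pairs[i]))]).to(device))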
                loss1,dis_a,dis_b = self.doTrainHLR(tensor_pairs[0])
                loss2 = self.doTrainInstanceOf(tensor_pairs[1])
                loss3 = self.doTrainSubClassOf(tensor_pairs[2])
                losses = loss1 + loss2 + loss3
                losses.backward()

                dis_a_L.append(torch.sqrt(dis_a).sum()), dis_b_L.append(torch.sqrt(dis_b).sum()) # for logs
                dis_count += dis_a.size(0)

                self.optimizer.step()
                res += losses.detach().cpu().numpy()

            print(sum(dis_a_L) / dis_count, sum(dis_b_L) / dis_count, dis_a.size())
            dis_a_L, dis_b_L = [], []
            dis_count = 0

            if epoch % 1 == 0:
                print("epoch:{} Res: {:.6f} Loss {:.6f},loss1: {:.6f},loss2: {:.6f},loss3 {:.6f}".format(epoch,res,losses,loss1,loss2,loss3))
            if epoch % 500 == 0 or epoch == nepoch - 1:
                entity_vec_save = self.entity_vec.weight.detach().cpu().numpy()
                concept_vec_save = self.concept_vec.weight.detach().cpu().numpy()
                relation_vec_save = self.relation_vec.weight.detach().cpu().numpy()

                # with open("embeddings/transc/"+self.args.version+"_embeddings_epoch" + str(epoch) + ".pkl", 'wb') as file:
                #    pkl.dump({"entity_vec": entity_vec_save,
                #             "concept_vec": concept_vec_save,
                #             "relation_vec":relation_vec_save},file)
                #print("saved!")

                #write for cpp test
                with open("vector/"+self.args["dataset"] +"/entity2vec.vec", 'w') as file:
                    for vec in entity_vec_save:
                        list_vec = list(vec)
                        str_vec = "\t".join([str(x) for x in list_vec])
                        file.write(str_vec+"\n")

                with open("vector/"+ self.args["dataset"] + "/relation2vec.vec", 'w') as file:
                    for vec in relation_vec_save:
                        list_vec = list(vec)
                        str_vec = "\t".join([str(x) for x in list_vec])
                        file.write(str_vec+"\n")

                with open("vector/" + self.args["dataset"]+"/concept2vec.vec", 'w') as file:
                    for vec in concept_vec_save:
                        list_vec = list(vec)
                        str_vec = "\t".join([str(x) for x in list_vec[:-1]])
                        file.write(str_vec + "\n" + str(list_vec[-1]) + "\n")
        # self.training_instance_file.close()


    def trainHLR(self, i, cut):
        pr = 0.5
        cur_fbr, cur_fbh, cur_fbt = self.D.fb_r[i], self.D.fb_h[i], self.D.fb_t[i]
        if self.args["bern"] == 1:
            pr = float(self.D.right_num[cur_fbr]) / (self.D.right_num[cur_fbr] + self.D.left_num[cur_fbr])
        if random.uniform(0, 1) < pr:
            loop=True
            while loop:

                if len(self.D.instance_brother[cur_fbt]) > 0:
                    if random.uniform(0, 1) < cut:
                        j = random.randint(0, self.D.entity_num - 1)
                    else:
                        j = random.randint(0, len(self.D.instance_brother[cur_fbt]) - 1)
                        j = self.D.instance_brother[cur_fbt][j]
                else:
                    j = random.randint(0, self.D.entity_num - 1)
                loop = j in self.D.ok[(cur_fbh, cur_fbr)]
            return cur_fbh, cur_fbt, cur_fbr, cur_fbh, j, cur_fbr
        else:
            loop=True
            while loop:
                if len(self.D.instance_brother[cur_fbh]) > 0:
                    if random.uniform(0, 1) < cut:
                        j = random.randint(0, self.D.entity_num - 1)
                    else:
                        j = random.randint(0, len(self.D.instance_brother[cur_fbh]) - 1)
                        j = self.D.instance_brother[cur_fbh][j]
                else:
                    j = random.randint(0, self.D.entity_num - 1)
                loop = ((j,cur_fbr) in self.D.ok) and (cur_fbt in self.D.ok[(j, cur_fbr)])
            return cur_fbh, cur_fbt, cur_fbr, j, cur_fbt, cur_fbr

    def trainInstanceOf(self, i, cut):
        i = i - len(self.D.fb_h)
        cur_ins,cur_cpt = self.D.instanceOf[i]
        if random.randint(0, 1) == 0:
            loop=True
            while loop:
                if len(self.D.instance_brother[cur_ins]) > 0: #
                    if random.uniform(0, 1) < cut:
                        j = random.randint(0, self.D.entity_num - 1)
                    else:
                        j = random.randint(0, len(self.D.instance_brother[cur_ins]) - 1)
                        j = self.D.instance_brother[cur_ins][j]
                else:
                    j = random.randint(0, self.D.entity_num - 1)
                loop = (j, cur_cpt) in self.D.instanceOf_ok
            return cur_ins, cur_cpt, j, cur_cpt

        else:
            loop=True
            while loop:
                if len(self.D.concept_brother[cur_cpt]) > 0: #
                    if random.uniform(0, 1) < cut:
                        j = random.randint(0, self.D.concept_num - 1)
                    else:
                        j = random.randint(0, len(self.D.concept_brother[cur_cpt]) - 1)
                        j = self.D.concept_brother[cur_cpt][j]
                else:
                    j = random.randint(0, self.D.concept_num - 1)
                loop = (cur_ins, j) in self.D.instanceOf_ok
            return cur_ins, cur_cpt, cur_ins, j

    def trainSubClassOf(self, i, cut):
        i = i - len(self.D.fb_h) - len(self.D.instanceOf)

        cur_cpth,cur_cptt=self.D.subClassOf[i]
        if random.randint(0, 1) == 0:
            loop=True
            while loop:
                if len(self.D.concept_brother[cur_cpth]) > 0: #
                    if random.uniform(0, 1) < cut:
                        j = random.randint(0, self.D.concept_num - 1)
                    else:
                        j = random.randint(0, len(self.D.concept_brother[cur_cpth]) - 1)
                        j = self.D.concept_brother[cur_cpth][j]
                else:
                    j = random.randint(0, self.D.concept_num - 1)
                loop = (j, cur_cptt) in self.D.subClassOf_ok
            return cur_cpth, cur_cptt, j, cur_cptt
        else:
            loop=True
            while loop:
                if len(self.D.concept_brother[cur_cptt]) > 0: #
                    if random.uniform(0, 1) < cut:
                        j = random.randint(0, self.D.concept_num - 1)
                    else:
                        j = random.randint(0, len(self.D.concept_brother[cur_cptt]) - 1)
                        j = self.D.concept_brother[cur_cptt][j]
                else:
                    j = random.randint(0, self.D.concept_num - 1)
                loop = (cur_cpth, j) in self.D.subClassOf_ok
            return cur_cpth, cur_cptt, cur_cpth, j

    def doTrainHLR(self, ids):
        entity_embs = self.entity_vec(ids[[0, 1, 3, 4], :])
        relation_embs = self.relation_vec(ids[[2, 5], :])

        dis_a = norm(entity_embs[0] + relation_embs[0] - entity_embs[1],pnorm=self.args["pnorm"])
        dis_b = norm(entity_embs[2] + relation_embs[1] - entity_embs[3],pnorm=self.args["pnorm"])

        loss = F.relu(dis_a + self.args["margin_hrt"] - dis_b).sum()
        return loss,dis_a,dis_b

    def doTrainInstanceOf(self, ids):
        entity_embs = self.entity_vec(ids[[0, 2], :])
        concept_embs = self.concept_vec(ids[[1, 3], :])
        radius = concept_embs[:, :, -1]
        concept_embs = concept_embs[:, :, :-1]

        if self.args["pnorm"]==1:
            dis = F.relu(norm(entity_embs - concept_embs,pnorm=self.args["pnorm"]) - torch.abs(radius))
        else:
            dis = F.relu(norm(entity_embs - concept_embs,pnorm=self.args["pnorm"]) - radius ** 2)

        loss = F.relu(dis[0] + self.args["margin_ins"] - dis[1]).sum()
        return loss

    def doTrainSubClassOf(self, ids):
        concept_embs_a = self.concept_vec(ids[[0,2],:])
        concept_embs_b = self.concept_vec(ids[[1, 3], :])
        radius_a = concept_embs_a[:, :, -1]
        radius_b = concept_embs_b[:, :, -1]

        concept_embs_a = concept_embs_a[:, :, :-1]
        concept_embs_b = concept_embs_b[:, :, :-1]

        if self.args["pnorm"]==1:
            dis = F.relu(norm(concept_embs_a - concept_embs_b,pnorm=self.args["pnorm"]) + torch.abs(radius_a) - torch.abs(radius_b))
        else:
            dis = F.relu(norm(concept_embs_a - concept_embs_b,pnorm=self.args["pnorm"]) + radius_a ** 2 - radius_b ** 2)

        loss = F.relu(dis[0] + self.args["margin_sub"] - dis[1]).sum()
        return loss

In [6]:
def read_file(dataset,filename,split = 'Train'):
    with open("data/" + dataset + "/" + split+"/"+filename + ".txt") as file:
        L = file.readlines()
        num = int(L[0].strip())
        contents = [[x for x in line.strip().split()] for line in L[1:]]
    return num, contents

In [7]:
def read_triples(dataset,filename,split = 'Train'):
    with open("data/" + dataset + "/" + split+"/"+filename + ".txt") as file:
        L = file.readlines()
        num = int(L[0].strip())
        contents = [[int(x) for x in line.strip().split()] for line in L[1:]]
    return num, contents

In [8]:
def read_biples(dataset, filename,split = 'Train'):
    with open("data/" + dataset + "/" + split+"/"+filename + ".txt") as file:
        L = file.readlines()
        contents = [[int(x) for x in line.strip().split()] for line in L[1:]]
    return contents

In [9]:
# def parseargs():
#     parsers = argparse.ArgumentParser()
#     parsers.add_argument("--emb_dim", type=int, default=100)
#     parsers.add_argument("--margin_hrt", type=float, default=1.0)
#     parsers.add_argument("--margin_ins", type=float, default=0.4)
#     parsers.add_argument("--margin_sub", type=float, default=0.3)
#     parsers.add_argument("--hrt_cut", type=float, default=0.8)
#     parsers.add_argument("--ins_cut", type=float, default=0.8)
#     parsers.add_argument("--sub_cut", type=float, default=0.8)

#     parsers.add_argument("--nepoch", type=float, default=1000)
#     parsers.add_argument("--nbatches", type=float, default=100)

#     parsers.add_argument("--lr", type=float, default=0.001)
#     parsers.add_argument("--bern", type=int, default=1)
#     parsers.add_argument("--pnorm", type=int, default=1)
#     parsers.add_argument("--dataset", type=str, default="YAGO39K")
#     parsers.add_argument("--split", type=str, default="Train")
#     parsers.add_argument("--version", type=str, default='tmp')

#     args = parsers.parse_args()
#     return args

The original code read its arguments from the command line with argparse. It was changed to use a config dictionary so that it works in Jupyter cells.

In [10]:
def parse_args(config=None):
    if config is None:
        config = {
            "emb_dim": 100,
            "margin_hrt": 1.0,
            "margin_ins": 0.4,
            "margin_sub": 0.3,
            "hrt_cut": 0.8,
            "ins_cut": 0.8,
            "sub_cut": 0.8,
            "nepoch": 1000,
            "nbatches": 100,
            "lr": 0.001,
            "bern": 1,
            "pnorm": 1,
            "dataset": "YAGO39K",
            "split": "Train",
            "version": "tmp",
        }
    return config

In [11]:
# def main():
#     args = parseargs()

#     if not os.path.exists("data/" + args.dataset + "/" + args.split + "/processed.pkl"):
#         dataset = Dataset(args=args)
#         dataset.setup()
#         dataset.save()
#     else:
#         dataset = load_processed(dataset_name=args.dataset, split=args.split)
#         print("dataset loaded")

#     train = Train(args = args,dataset= dataset).cuda()
#     train.doTrain()

In [12]:
def main(config=None):
    if config is None:
        config = parse_args()

    if not os.path.exists("data/" + config["dataset"] + "/" + config["split"] + "/processed.pkl"):
        dataset = Dataset(args=config)
        dataset.setup()
        dataset.save()
    else:
        dataset = load_processed(dataset_name=config["dataset"], split=config["split"])
        print("dataset loaded")

    train = Train(args=config, dataset=dataset).cuda()
    train.doTrain()

The run was exhausting RAM, so the main() function had to be changed. dataset.setup() and dataset.save() were run manually, which produced a 1.4 GB file. The training run then only loads that file.

In [13]:
# config = parse_args()
# dataset = Dataset(args=config)
# dataset.setup()
# dataset.save()

In [14]:
# config = parse_args()
# dataset = load_processed(dataset_name=config["dataset"], split=config["split"])
# print("dataset loaded")
# train = Train(args=config, dataset=dataset)
# train.doTrain()

In [17]:
main()

train size 822011 354996 437835 29180
tensor(2.2611, device='cuda:0', grad_fn=<DivBackward0>) tensor(2.4010, device='cuda:0', grad_fn=<DivBackward0>) torch.Size([3583])
epoch:0 Res: 349234.228516 Loss 2324.558594,loss1: 1057.178589,loss2: 1203.410400,loss3 63.969723
tensor(2.2419, device='cuda:0', grad_fn=<DivBackward0>) tensor(2.5067, device='cuda:0', grad_fn=<DivBackward0>) torch.Size([3499])
epoch:1 Res: 184564.315308 Loss 1721.437744,loss1: 750.440491,loss2: 922.411743,loss3 48.585434
tensor(2.2092, device='cuda:0', grad_fn=<DivBackward0>) tensor(2.5265, device='cuda:0', grad_fn=<DivBackward0>) torch.Size([3438])
epoch:2 Res: 144676.873535 Loss 1189.267578,loss1: 664.674255,loss2: 471.961700,loss3 52.631554
tensor(2.1948, device='cuda:0', grad_fn=<DivBackward0>) tensor(2.5429, device='cuda:0', grad_fn=<DivBackward0>) torch.Size([3587])
epoch:3 Res: 125863.120605 Loss 1042.218750,loss1: 640.308472,loss2: 341.972534,loss3 59.937706
tensor(2.1969, device='cuda:0', grad_fn=<DivBackward

In [20]:
# model = torch.load()
# torch.save(model.state_dict(), PATH)
# model.eval()