Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an NLP loss term to the reward function and slightly changed the parameter passing pattern. #4

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
__pycache__
.*
!/.gitignore
dask-worker-space
exp
76 changes: 76 additions & 0 deletions data/political_democracy.csv
@@ -0,0 +1,76 @@
y1,y2,y3,y4,y5,y6,y7,y8,x1,x2,x3
2.5,0.0,3.333333,0.0,1.25,0.0,3.72636,3.333333,4.442651,3.637586,2.557615
1.25,0.0,3.333333,0.0,6.25,1.1,6.666666,0.736999,5.384495,5.062595,3.568079
7.5,8.8,9.999998,9.199991,8.75,8.094061,9.999998,8.211809,5.961005,6.25575,5.224433
8.9,8.8,9.999998,9.199991,8.907948,8.127979,9.999998,4.615086,6.285998,7.567863,6.267495
10.0,3.333333,9.999998,6.666666,7.5,3.333333,9.999998,6.666666,5.863631,6.818924,4.573679
7.5,3.333333,6.666666,6.666666,6.25,1.1,6.666666,0.3685,5.533389,5.135798,3.89227
7.5,3.333333,6.666666,6.666666,5.0,2.233333,8.271257,1.485166,5.308268,5.075174,3.316213
7.5,2.233333,9.999998,1.496333,6.25,3.333333,9.999998,6.666666,5.347108,4.85203,4.263183
2.5,3.333333,3.333333,3.333333,6.25,3.333333,3.333333,3.333333,5.521461,5.241747,4.115168
10.0,6.666666,9.999998,8.899991,8.75,6.666666,9.999998,10.0,5.828946,5.370638,4.446216
7.5,3.333333,9.999998,6.666666,8.75,3.333333,9.999998,6.666666,5.916202,6.423247,3.791545
7.5,3.333333,6.666666,6.666666,8.75,3.333333,6.666666,6.666666,5.398163,6.246107,4.535708
7.5,3.333333,9.999998,6.666666,7.5,3.333333,6.666666,10.0,6.622736,7.872074,4.906154
7.5,7.766664,9.999998,6.666666,7.5,0.0,9.999998,0.0,5.204007,5.225747,4.561047
7.5,9.999998,3.333333,10.0,7.5,6.666666,9.999998,10.0,5.509388,6.202536,4.586286
7.5,9.999998,9.999998,7.766666,7.5,1.1,6.666666,6.666666,5.26269,5.820083,3.948911
2.5,3.333333,6.666666,6.666666,5.0,1.1,6.666666,0.3685,4.70048,5.023881,4.394491
1.25,0.0,3.333333,3.333333,1.25,3.333333,3.333333,3.333333,5.209486,4.465908,4.510268
10.0,9.999998,9.999998,10.0,8.75,9.999998,9.999998,10.0,5.916202,6.732211,5.829084
7.5,3.333299,3.333333,6.666666,7.5,2.233299,6.666666,2.948164,6.523562,6.992096,6.424591
10.0,9.999998,9.999998,10.0,10.0,9.999998,9.999998,10.0,6.238325,6.746412,5.741711
1.25,0.0,0.0,0.0,2.5,0.0,0.0,0.0,5.976351,6.712956,5.948168
2.5,0.0,3.333333,3.333333,2.5,0.0,3.333333,3.333333,5.631212,5.937536,5.686755
7.5,6.666666,9.999998,10.0,7.5,6.666666,9.999998,7.766666,6.033086,6.09357,4.611429
8.5,9.999998,6.666666,6.666666,8.75,9.999998,7.351018,6.666666,6.196444,6.704414,5.475261
6.1,0.0,5.4,3.333333,0.0,0.0,4.696028,3.333333,4.248495,2.70805,1.74083
3.3,0.0,6.666666,3.333333,6.25,0.0,6.666666,3.333333,5.141664,4.564348,2.255134
2.9,3.333333,6.666666,3.333333,2.385559,0.0,3.177568,1.116666,4.174387,3.688879,3.046927
9.2,0.0,9.9,3.333333,7.60966,0.0,8.118828,3.333333,4.382027,2.890372,1.711279
6.9,0.0,6.666666,3.333333,4.226033,0.0,0.0,0.0,4.290459,1.609438,1.001674
2.9,0.0,3.333333,3.333333,5.0,0.0,3.333333,3.333333,4.934474,4.234107,1.418971
2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.850148,1.94591,2.345229
5.0,0.0,3.333333,3.333333,5.0,0.0,3.333333,3.333333,5.181784,4.394449,3.167167
5.0,0.0,9.999998,3.333333,0.0,0.0,3.333333,0.74437,5.062595,4.59512,3.83497
4.1,9.999998,4.7,6.666666,3.75,0.0,7.827667,6.666666,4.691348,4.143135,2.255134
6.3,9.999998,9.999998,6.666666,6.25,2.233333,6.666666,2.955702,4.248495,3.367296,3.217506
5.2,4.999998,6.6,3.333333,3.633403,1.1,3.314128,3.333333,5.56452,5.236442,2.677633
5.0,3.333333,6.4,6.666666,2.844997,0.0,4.429657,1.485166,4.727388,3.610918,1.418971
3.1,4.999998,4.2,5.0,3.75,0.0,6.164304,3.333333,4.143135,2.302585,1.418971
4.1,9.999998,6.666666,3.333333,5.0,0.0,4.938089,2.233333,4.317488,4.955827,4.249888
5.0,9.999998,6.666666,1.666666,5.0,0.0,6.666666,0.3685,5.141664,4.430817,3.046927
5.0,7.7,6.666666,8.399997,6.25,4.358243,9.999998,4.141377,4.488636,3.465736,2.013579
5.0,6.2,9.999998,6.060997,5.0,2.782771,6.666666,4.974739,4.615121,4.941642,2.255134
5.6,4.9,0.0,0.0,6.555647,4.055463,6.666666,3.821796,3.850148,2.397895,1.74083
5.7,4.8,0.0,0.0,6.555647,4.055463,0.0,0.0,3.970292,2.397895,1.050741
7.5,9.999998,7.9,6.666666,3.75,9.999998,7.631891,6.666666,3.78419,3.091042,2.113313
2.5,0.0,6.666666,3.333333,2.5,0.0,0.0,0.0,3.806662,2.079442,2.137561
8.9,9.999998,9.7,6.666666,5.0,9.999998,9.556024,6.666666,4.532599,3.610918,1.587802
7.6,0.0,10.0,0.0,5.0,1.1,6.666666,1.099999,5.117994,4.934474,3.83497
7.8,9.999998,6.666666,6.666666,5.0,3.333333,6.666666,6.666666,5.049856,5.111988,4.38149
2.5,0.0,6.666666,3.333333,5.0,0.0,6.666666,3.333333,5.393628,5.638355,4.169451
3.8,0.0,5.1,0.0,3.75,0.0,6.666666,1.485166,4.477337,3.931826,2.474671
5.0,3.333333,3.333333,2.233333,5.0,3.333333,6.666666,5.566663,5.257495,5.840642,5.001796
6.25,3.333333,9.999998,2.955702,6.25,5.566663,9.999998,6.666666,5.379897,5.505332,3.299937
1.25,0.0,3.333333,0.0,2.5,0.0,0.0,0.0,5.298317,6.274762,4.38149
1.25,0.0,4.7,0.736999,2.5,0.0,3.333333,3.333333,4.859812,5.669881,3.537416
1.25,0.0,6.666666,0.0,2.5,0.0,5.228375,0.0,4.969813,5.56452,4.510268
7.5,7.766664,9.999998,6.666666,7.5,3.333333,9.999998,6.666666,6.011267,6.253829,5.001796
2.5,0.0,6.666666,4.433333,5.0,0.0,6.666666,1.485166,5.075174,5.252273,5.350708
7.5,9.999998,9.999998,10.0,8.75,9.999998,9.999998,10.0,6.736967,7.125283,6.330518
1.25,0.0,0.0,0.0,1.25,0.0,0.0,0.0,5.225747,5.451038,3.167167
1.25,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.025352,1.791759,2.657972
2.5,0.0,0.0,0.0,0.0,0.0,6.666666,2.948164,4.234107,2.70805,2.474671
6.25,2.233299,6.666666,2.970332,3.75,3.333299,6.666666,3.333333,4.644391,5.56452,3.046927
7.5,9.999998,9.999998,10.0,7.5,9.999998,9.999998,10.0,4.418841,4.941642,3.380653
5.0,0.0,6.1,0.0,5.0,3.333333,9.999998,3.333333,4.26268,4.219508,4.368462
7.5,9.999998,9.999998,10.0,3.75,9.999998,9.999998,10.0,4.875197,4.70048,3.83497
4.9,2.233333,9.999998,0.0,5.0,0.0,3.621989,3.333333,4.189655,1.386294,1.418971
5.0,0.0,8.2,0.0,5.0,0.0,0.0,0.0,4.521789,4.127134,2.113313
2.9,3.333333,6.666666,3.333333,2.5,3.333333,6.666666,3.333333,4.65396,3.555348,1.881917
5.4,9.999998,6.666666,3.333333,3.75,6.666666,6.666666,1.485166,4.477337,3.091042,1.987909
7.5,8.8,9.999998,6.066666,7.5,6.666666,9.999998,6.666666,5.337538,5.631212,3.491004
7.5,7.0,9.999998,6.852998,7.5,6.34834,6.666666,7.508044,6.12905,6.403574,5.001796
10.0,6.666666,9.999998,10.0,10.0,6.666666,9.999998,10.0,5.003946,4.962845,3.976994
3.75,3.333333,0.0,0.0,1.25,3.333333,0.0,0.0,4.488636,4.89784,2.867566
19 changes: 19 additions & 0 deletions frontend/README.md
@@ -0,0 +1,19 @@
# Front-end

The PC port's front end design with Qt.



Quick file introduction:

- ` images/` : sources fold
- `mainwindow.cpp` : main part for the input canvas
- `diagramscene` & `diagramitem` & `diagramtextitem` & `arrow`: components of the canvas
- `outputwindow`: parse the output model and display it as picture



Notice:

1. Relationship between the two windows haven't been established yet, so please change the `w.show()` annotation in `main.cpp` to alter which window to show during the debug process.
2. The input canvas design is modified from Qt Official Example "diagramscene", check it out if there's some difficult understanding the code.
94 changes: 94 additions & 0 deletions frontend/arrow.cpp
@@ -0,0 +1,94 @@
#include "arrow.h"

#include <qmath.h>
#include <QPen>
#include <QPainter>

/*
* Arrow between two Diagram Items.
*
**/
Arrow::Arrow(DiagramItem *startItem, DiagramItem *endItem, QGraphicsItem *parent)
: QGraphicsLineItem(parent)
{
myStartItem = startItem;
myEndItem = endItem;
setFlag(QGraphicsItem::ItemIsSelectable, true);
myColor = Qt::black;
setPen(QPen(myColor, 2, Qt::SolidLine, Qt::RoundCap, Qt::RoundJoin));
}

QRectF Arrow::boundingRect() const
{
qreal extra = (pen().width() + 20) / 2.0;

return QRectF(line().p1(), QSizeF(line().p2().x() - line().p1().x(),
line().p2().y() - line().p1().y()))
.normalized()
.adjusted(-extra, -extra, extra, extra);
}

QPainterPath Arrow::shape() const
{
QPainterPath path = QGraphicsLineItem::shape();
path.addPolygon(arrowHead);
return path;
}

void Arrow::updatePosition()
{
QLineF line(mapFromItem(myStartItem, 0, 0), mapFromItem(myEndItem, 0, 0));
setLine(line);
}

void Arrow::paint(QPainter *painter, const QStyleOptionGraphicsItem *,
QWidget *)
{
if (myStartItem->collidesWithItem(myEndItem))
return;

QPen myPen = pen();
myPen.setColor(myColor);
qreal arrowSize = 20;
painter->setPen(myPen);
painter->setBrush(myColor);

QLineF centerLine(myStartItem->pos(), myEndItem->pos());
QPolygonF endPolygon = myEndItem->polygon();
QPointF p1 = endPolygon.first() + myEndItem->pos();
QPointF p2;
QPointF intersectPoint;
QLineF polyLine;
for (int i = 1; i < endPolygon.count(); ++i) {
p2 = endPolygon.at(i) + myEndItem->pos();
polyLine = QLineF(p1, p2);
QLineF::IntersectType intersectType =
polyLine.intersect(centerLine, &intersectPoint);
if (intersectType == QLineF::BoundedIntersection)
break;
p1 = p2;
}

setLine(QLineF(intersectPoint, myStartItem->pos()));

double angle = std::atan2(-line().dy(), line().dx());

QPointF arrowP1 = line().p1() + QPointF(sin(angle + M_PI / 3) * arrowSize,
cos(angle + M_PI / 3) * arrowSize);
QPointF arrowP2 = line().p1() + QPointF(sin(angle + M_PI - M_PI / 3) * arrowSize,
cos(angle + M_PI - M_PI / 3) * arrowSize);

arrowHead.clear();
arrowHead << line().p1() << arrowP1 << arrowP2;

painter->drawLine(line());
painter->drawPolygon(arrowHead);
if (isSelected()) {
painter->setPen(QPen(myColor, 1, Qt::DashLine));
QLineF myLine = line();
myLine.translate(0, 4.0);
painter->drawLine(myLine);
myLine.translate(0,-8.0);
painter->drawLine(myLine);
}
}
45 changes: 45 additions & 0 deletions frontend/arrow.h
@@ -0,0 +1,45 @@
#ifndef ARROW_H
#define ARROW_H

#include <QGraphicsLineItem>

#include "diagramitem.h"

QT_BEGIN_NAMESPACE
class QGraphicsPolygonItem;
class QGraphicsLineItem;
class QGraphicsScene;
class QRectF;
class QGraphicsSceneMouseEvent;
class QPainterPath;
QT_END_NAMESPACE


class Arrow : public QGraphicsLineItem
{
public:
enum { Type = UserType + 4 };

Arrow(DiagramItem *startItem, DiagramItem *endItem,
QGraphicsItem *parent = nullptr);

int type() const override { return Type; }
QRectF boundingRect() const override;
QPainterPath shape() const override;
void setColor(const QColor &color) { myColor = color; }
DiagramItem *startItem() const { return myStartItem; }
DiagramItem *endItem() const { return myEndItem; }

void updatePosition();

protected:
void paint(QPainter *painter, const QStyleOptionGraphicsItem *option, QWidget *widget = nullptr) override;

private:
DiagramItem *myStartItem;
DiagramItem *myEndItem;
QColor myColor;
QPolygonF arrowHead;
};

#endif // ARROW_H
104 changes: 104 additions & 0 deletions frontend/diagramitem.cpp
@@ -0,0 +1,104 @@
#include "diagramitem.h"
#include "arrow.h"

#include <QGraphicsScene>
#include <QGraphicsSceneContextMenuEvent>
#include <QMenu>
#include <QPainter>

/*
* Diagram Item that enables creating, moving and handles arrow changing.
*
**/
DiagramItem::DiagramItem(DiagramType diagramType, QMenu *contextMenu,
QGraphicsItem *parent)
: QGraphicsPolygonItem(parent)
{
myDiagramType = diagramType;
myContextMenu = contextMenu;

QPainterPath path;
switch (myDiagramType) {
case StartEnd:
path.moveTo(200, 50);
path.arcTo(150, 0, 50, 50, 0, 90);
path.arcTo(50, 0, 50, 50, 90, 90);
path.arcTo(50, 50, 50, 50, 180, 90);
path.arcTo(150, 50, 50, 50, 270, 90);
path.lineTo(200, 25);
myPolygon = path.toFillPolygon();
break;
case Conditional:
myPolygon << QPointF(-100, 0) << QPointF(0, 100)
<< QPointF(100, 0) << QPointF(0, -100)
<< QPointF(-100, 0);
break;
case Step:
myPolygon << QPointF(-100, -100) << QPointF(100, -100)
<< QPointF(100, 100) << QPointF(-100, 100)
<< QPointF(-100, -100);
break;
default:
myPolygon << QPointF(-120, -80) << QPointF(-70, 80)
<< QPointF(120, 80) << QPointF(70, -80)
<< QPointF(-120, -80);
break;
}

setPolygon(myPolygon);
setFlag(QGraphicsItem::ItemIsMovable, true);
setFlag(QGraphicsItem::ItemIsSelectable, true);
setFlag(QGraphicsItem::ItemSendsGeometryChanges, true);
}

void DiagramItem::removeArrow(Arrow *arrow)
{
int index = arrows.indexOf(arrow);

if (index != -1)
arrows.removeAt(index);
}

void DiagramItem::removeArrows()
{
foreach (Arrow *arrow, arrows) {
arrow->startItem()->removeArrow(arrow);
arrow->endItem()->removeArrow(arrow);
scene()->removeItem(arrow);
delete arrow;
}
}

void DiagramItem::addArrow(Arrow *arrow)
{
arrows.append(arrow);
}

QPixmap DiagramItem::image() const
{
QPixmap pixmap(250, 250);
pixmap.fill(Qt::transparent);
QPainter painter(&pixmap);
painter.setPen(QPen(Qt::black, 8));
painter.translate(125, 125);
painter.drawPolyline(myPolygon);
return pixmap;
}

void DiagramItem::contextMenuEvent(QGraphicsSceneContextMenuEvent *event)
{
scene()->clearSelection();
setSelected(true);
myContextMenu->exec(event->screenPos());
}

QVariant DiagramItem::itemChange(GraphicsItemChange change, const QVariant &value)
{
if (change == QGraphicsItem::ItemPositionChange) {
foreach (Arrow *arrow, arrows) {
arrow->updatePosition();
}
}

return value;
}